Skip to content

Commit

Permalink
fix: umount issue leaves global lock (#95)
Browse files Browse the repository at this point in the history
We saw an issue where one Unmount operation suspended, and because it
had the `volumesLock` locked, subsequent operations also all suspended
because they could not obtain the lock.

We have refactored locking access to the `volumes` map so that fewer
operations happen under lock, and a deadlock is now less likely.

Co-authored-by: Konstantin Kiess <kkiess@vmware.com>
  • Loading branch information
blgm and Konstantin Kiess committed Jul 6, 2023
1 parent 8916756 commit 39b2964
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 61 deletions.
73 changes: 73 additions & 0 deletions internal/syncmap/syncmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package syncmap

import (
"encoding/json"
"sync"
)

// New returns an empty SyncMap ready for use.
func New[A any]() *SyncMap[A] {
	return &SyncMap[A]{data: make(map[string]A)}
}

// SyncMap is a map from string keys to values of type A that is safe for
// concurrent use by multiple goroutines. The zero value is ready to use;
// New is provided as a convenience.
//
// A SyncMap must not be copied after first use, because it contains a
// sync.RWMutex.
type SyncMap[A any] struct {
	data map[string]A
	lock sync.RWMutex
}

// Put stores value under key, replacing any existing entry for that key.
func (s *SyncMap[A]) Put(key string, value A) {
	s.lock.Lock()
	defer s.lock.Unlock()

	// Lazily allocate the map so the zero value of SyncMap is usable,
	// not only instances created via New. Without this, Put on a
	// zero-value SyncMap panics ("assignment to entry in nil map").
	if s.data == nil {
		s.data = make(map[string]A)
	}
	s.data[key] = value
}

// Get returns the value stored under key and whether the key was present.
// When the key is absent it returns the zero value of A and false.
func (s *SyncMap[A]) Get(key string) (A, bool) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	val, ok := s.data[key]
	return val, ok
}

// Delete removes the entry for key, if any. Deleting an absent key is a no-op.
func (s *SyncMap[A]) Delete(key string) {
	s.lock.Lock()
	defer s.lock.Unlock()

	delete(s.data, key)
}

// MarshalJSON implements json.Marshaler by encoding the underlying map,
// holding a read lock for the duration of the encoding.
func (s *SyncMap[A]) MarshalJSON() ([]byte, error) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return json.Marshal(s.data)
}

// UnmarshalJSON implements json.Unmarshaler. Decoded entries are merged
// into the existing map (standard json.Unmarshal semantics for maps);
// keys not present in the input are left untouched.
func (s *SyncMap[A]) UnmarshalJSON(data []byte) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	return json.Unmarshal(data, &s.data)
}

// Keys returns a snapshot of the map's keys in unspecified order.
func (s *SyncMap[A]) Keys() []string {
	s.lock.RLock()
	defer s.lock.RUnlock()

	result := make([]string, 0, len(s.data))
	for key := range s.data {
		result = append(result, key)
	}
	return result
}

// Values returns a snapshot of the map's values in unspecified order.
func (s *SyncMap[A]) Values() []A {
	s.lock.RLock()
	defer s.lock.RUnlock()

	result := make([]A, 0, len(s.data))
	for _, value := range s.data {
		result = append(result, value)
	}
	return result
}
13 changes: 13 additions & 0 deletions internal/syncmap/syncmap_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package syncmap_test

import (
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

// TestSyncmap is the single entry point through which `go test` runs the
// Ginkgo specs in this package: it routes Gomega assertion failures to
// Ginkgo's fail handler and then hands control to the spec runner.
func TestSyncmap(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Syncmap Suite")
}
94 changes: 94 additions & 0 deletions internal/syncmap/syncmap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package syncmap_test

import (
"encoding/json"
"fmt"
"sync"
"time"

"code.cloudfoundry.org/volumedriver/internal/syncmap"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

// Specs for SyncMap covering concurrent access, basic map operations,
// JSON round-tripping, and key/value snapshots.
var _ = Describe("SyncMap", func() {
	It("can put and get concurrently", func() {
		s := syncmap.New[int]()

		const workers = 100
		var wg sync.WaitGroup
		wg.Add(workers)
		for i := 0; i < workers; i++ {
			// The loop variable is passed as an argument so each goroutine
			// captures its own workerID rather than the shared loop variable.
			go func(workerID int) {
				defer GinkgoRecover()
				defer wg.Done()

				key := fmt.Sprintf("%d", workerID)
				s.Put(key, workerID)
				// Sleep so the goroutines' map accesses overlap in time,
				// giving the race detector a chance to observe any
				// unsynchronised access.
				time.Sleep(100 * time.Millisecond)
				value, _ := s.Get(key)
				Expect(value).To(Equal(workerID))
			}(i)
		}

		wg.Wait()
	})

	It("can report whether a value exists", func() {
		s := syncmap.New[int]()
		s.Put("exists", 42)

		v1, ok1 := s.Get("exists")
		Expect(v1).To(Equal(42))
		Expect(ok1).To(BeTrue())

		// A missing key yields the zero value of the element type and false.
		v2, ok2 := s.Get("doesn't exist")
		Expect(v2).To(Equal(0))
		Expect(ok2).To(BeFalse())
	})

	It("can delete a value", func() {
		const key = "exists"
		s := syncmap.New[int]()
		s.Put(key, 42)
		_, ok1 := s.Get(key)
		Expect(ok1).To(BeTrue())

		s.Delete(key)
		_, ok2 := s.Get(key)
		Expect(ok2).To(BeFalse())
	})

	It("can be marshalled into JSON", func() {
		s := syncmap.New[any]()
		s.Put("foo", "bar")
		s.Put("baz", 42)
		// MatchJSON compares structurally, so map iteration order is irrelevant.
		Expect(json.Marshal(s)).To(MatchJSON(`{"foo":"bar","baz":42}`))
	})

	It("can be unmarshalled from JSON", func() {
		const input = `{"foo":"bar","baz":42}`
		s := syncmap.New[any]()

		// Round-trip: unmarshalling then marshalling reproduces the input.
		Expect(json.Unmarshal([]byte(input), s)).To(Succeed())
		Expect(json.Marshal(s)).To(MatchJSON(input))
	})

	It("can return a list of keys", func() {
		s := syncmap.New[any]()
		s.Put("foo", "bar")
		s.Put("baz", 42)
		s.Put("quz", false)

		// ConsistOf is order-independent; Keys gives no ordering guarantee.
		Expect(s.Keys()).To(ConsistOf("foo", "baz", "quz"))
	})

	It("can return a list of values", func() {
		s := syncmap.New[string]()
		s.Put("foo", "bar")
		s.Put("baz", "quz")
		s.Put("duz", "fuz")

		// ConsistOf is order-independent; Values gives no ordering guarantee.
		Expect(s.Values()).To(ConsistOf("bar", "fuz", "quz"))
	})
})
92 changes: 31 additions & 61 deletions volume_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"code.cloudfoundry.org/goshims/osshim"
"code.cloudfoundry.org/goshims/timeshim"
"code.cloudfoundry.org/lager/v3"
"code.cloudfoundry.org/volumedriver/internal/syncmap"
"code.cloudfoundry.org/volumedriver/mountchecker"
)

Expand All @@ -32,8 +33,7 @@ type OsHelper interface {
}

type VolumeDriver struct {
volumes map[string]*NfsVolumeInfo
volumesLock sync.RWMutex
volumes *syncmap.SyncMap[*NfsVolumeInfo]
os osshim.Os
filepath filepathshim.Filepath
ioutil ioutilshim.Ioutil
Expand All @@ -46,7 +46,7 @@ type VolumeDriver struct {

func NewVolumeDriver(logger lager.Logger, os osshim.Os, filepath filepathshim.Filepath, ioutil ioutilshim.Ioutil, time timeshim.Time, mountChecker mountchecker.MountChecker, mountPathRoot string, mounter Mounter, oshelper OsHelper) *VolumeDriver {
d := &VolumeDriver{
volumes: map[string]*NfsVolumeInfo{},
volumes: syncmap.New[*NfsVolumeInfo](),
os: os,
filepath: filepath,
ioutil: ioutil,
Expand Down Expand Up @@ -97,17 +97,11 @@ func (d *VolumeDriver) Create(env dockerdriver.Env, createRequest dockerdriver.C
Opts: createRequest.Opts,
}

d.volumesLock.Lock()
defer d.volumesLock.Unlock()

d.volumes[createRequest.Name] = &volInfo
d.volumes.Put(createRequest.Name, &volInfo)
} else {
existing.Opts = createRequest.Opts

d.volumesLock.Lock()
defer d.volumesLock.Unlock()

d.volumes[createRequest.Name] = existing
d.volumes.Put(createRequest.Name, existing)
}

err = d.persistState(driverhttp.EnvWithLogger(logger, env))
Expand All @@ -120,15 +114,12 @@ func (d *VolumeDriver) Create(env dockerdriver.Env, createRequest dockerdriver.C
}

func (d *VolumeDriver) List(_ dockerdriver.Env) dockerdriver.ListResponse {
d.volumesLock.RLock()
defer d.volumesLock.RUnlock()

listResponse := dockerdriver.ListResponse{
Volumes: []dockerdriver.VolumeInfo{},
}

for _, volume := range d.volumes {
listResponse.Volumes = append(listResponse.Volumes, volume.VolumeInfo)
for _, val := range d.volumes.Values() {
listResponse.Volumes = append(listResponse.Volumes, val.VolumeInfo)
}
listResponse.Err = ""
return listResponse
Expand All @@ -150,11 +141,8 @@ func (d *VolumeDriver) Mount(env dockerdriver.Env, mountRequest dockerdriver.Mou

ret := func() dockerdriver.MountResponse {

d.volumesLock.Lock()
defer d.volumesLock.Unlock()

volume := d.volumes[mountRequest.Name]
if volume == nil {
volume, ok := d.volumes.Get(mountRequest.Name)
if !ok {
return dockerdriver.MountResponse{Err: fmt.Sprintf("Volume '%s' must be created before being mounted", mountRequest.Name)}
}

Expand Down Expand Up @@ -202,11 +190,8 @@ func (d *VolumeDriver) Mount(env dockerdriver.Env, mountRequest dockerdriver.Mou
}

func() {
d.volumesLock.Lock()
defer d.volumesLock.Unlock()

volume := d.volumes[mountRequest.Name]
if volume == nil {
volume, ok := d.volumes.Get(mountRequest.Name)
if !ok {
ret = dockerdriver.MountResponse{Err: fmt.Sprintf("Volume '%s' not found", mountRequest.Name)}
} else if err != nil {
if _, ok := err.(dockerdriver.SafeError); ok {
Expand All @@ -228,11 +213,8 @@ func (d *VolumeDriver) Mount(env dockerdriver.Env, mountRequest dockerdriver.Mou
wg.Wait()

return func() dockerdriver.MountResponse {
d.volumesLock.Lock()
defer d.volumesLock.Unlock()

volume := d.volumes[mountRequest.Name]
if volume == nil {
volume, ok := d.volumes.Get(mountRequest.Name)
if !ok {
return dockerdriver.MountResponse{Err: fmt.Sprintf("Volume '%s' not found", mountRequest.Name)}
} else if volume.mountError != "" {
return dockerdriver.MountResponse{Err: volume.mountError}
Expand Down Expand Up @@ -281,10 +263,7 @@ func (d *VolumeDriver) Unmount(env dockerdriver.Env, unmountRequest dockerdriver
return dockerdriver.ErrorResponse{Err: "Missing mandatory 'volume_name'"}
}

d.volumesLock.Lock()
defer d.volumesLock.Unlock()

volume, ok := d.volumes[unmountRequest.Name]
volume, ok := d.volumes.Get(unmountRequest.Name)
if !ok {
logger.Error("failed-no-such-volume-found", fmt.Errorf("could not find volume %s", unmountRequest.Name))

Expand All @@ -307,7 +286,7 @@ func (d *VolumeDriver) Unmount(env dockerdriver.Env, unmountRequest dockerdriver
logger.Info("volume-ref-count-decremented", lager.Data{"name": volume.Name, "count": volume.MountCount})

if volume.MountCount < 1 {
delete(d.volumes, unmountRequest.Name)
d.volumes.Delete(unmountRequest.Name)
}

if err := d.persistState(driverhttp.EnvWithLogger(logger, env)); err != nil {
Expand Down Expand Up @@ -341,9 +320,7 @@ func (d *VolumeDriver) Remove(env dockerdriver.Env, removeRequest dockerdriver.R

logger.Info("removing-volume", lager.Data{"name": removeRequest.Name})

d.volumesLock.Lock()
defer d.volumesLock.Unlock()
delete(d.volumes, removeRequest.Name)
d.volumes.Delete(removeRequest.Name)

if err := d.persistState(driverhttp.EnvWithLogger(logger, env)); err != nil {
return dockerdriver.ErrorResponse{Err: fmt.Sprintf("failed to persist state when removing: %s", err.Error())}
Expand All @@ -368,10 +345,8 @@ func (d *VolumeDriver) Get(env dockerdriver.Env, getRequest dockerdriver.GetRequ

func (d *VolumeDriver) getVolume(env dockerdriver.Env, volumeName string) (*NfsVolumeInfo, error) {
logger := env.Logger().Session("get-volume")
d.volumesLock.RLock()
defer d.volumesLock.RUnlock()

if vol, ok := d.volumes[volumeName]; ok {
if vol, ok := d.volumes.Get(volumeName); ok {
logger.Info("getting-volume", lager.Data{"name": volumeName})
return vol, nil
}
Expand Down Expand Up @@ -483,21 +458,13 @@ func (d *VolumeDriver) restoreState(env dockerdriver.Env) {
logger.Info("failed-to-read-state-file", lager.Data{"err": err, "stateFile": stateFile})
return
}
logger.Info("state", lager.Data{"state": string(stateData)})

state := map[string]*NfsVolumeInfo{}
err = json.Unmarshal(stateData, &state)

logger.Info("state", lager.Data{"state": state})

if err != nil {
if err := json.Unmarshal(stateData, d.volumes); err != nil {
logger.Error("failed-to-unmarshall-state", err, lager.Data{"stateFile": stateFile})
return
}
logger.Info("state-restored", lager.Data{"state-file": stateFile})

d.volumesLock.Lock()
defer d.volumesLock.Unlock()
d.volumes = state
}

func (d *VolumeDriver) unmount(env dockerdriver.Env, name string, mountPath string) error {
Expand Down Expand Up @@ -547,9 +514,10 @@ func (d *VolumeDriver) checkMounts(env dockerdriver.Env) {
logger.Info("start")
defer logger.Info("end")

for key, mount := range d.volumes {
if !d.mounter.Check(driverhttp.EnvWithLogger(logger, env), key, mount.VolumeInfo.Mountpoint) {
delete(d.volumes, key)
for _, key := range d.volumes.Keys() {
mount, ok := d.volumes.Get(key)
if ok && !d.mounter.Check(driverhttp.EnvWithLogger(logger, env), key, mount.VolumeInfo.Mountpoint) {
d.volumes.Delete(key)
}
}
}
Expand All @@ -560,14 +528,16 @@ func (d *VolumeDriver) Drain(env dockerdriver.Env) error {
defer logger.Info("end")

// flush any volumes that are still in our map
for key, mount := range d.volumes {
if mount.Mountpoint != "" && mount.MountCount > 0 {
err := d.unmount(env, mount.Name, mount.Mountpoint)
if err != nil {
logger.Error("drain-unmount-failed", err, lager.Data{"mount-name": mount.Name, "mount-point": mount.Mountpoint})
for _, key := range d.volumes.Keys() {
if mount, ok := d.volumes.Get(key); ok {
if mount.Mountpoint != "" && mount.MountCount > 0 {
err := d.unmount(env, mount.Name, mount.Mountpoint)
if err != nil {
logger.Error("drain-unmount-failed", err, lager.Data{"mount-name": mount.Name, "mount-point": mount.Mountpoint})
}
}
d.volumes.Delete(key)
}
delete(d.volumes, key)
}

d.mounter.Purge(env, d.mountPathRoot)
Expand Down

0 comments on commit 39b2964

Please sign in to comment.