Merge pull request #34 from dask/invalidate-cache
Invalidate rather than update cache
mrocklin committed May 3, 2016
2 parents 308e2bf + 94029c8 commit e75362e
Showing 3 changed files with 82 additions and 84 deletions.
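The change itself is small: instead of eagerly re-listing a bucket after every mutating operation (`self._ls(bucket, refresh=True)`), the filesystem now drops the cached listing for that bucket (`self.invalidate_cache(bucket)`) and lets the next read re-populate it lazily. The sketch below illustrates that pattern in isolation; it is not s3fs code, and `list_bucket`/`put_object` are placeholder callables standing in for the boto3 calls.

class CachedListing:
    """Toy stand-in for S3FileSystem's per-bucket listing cache (self.dirs)."""

    def __init__(self, list_bucket, put_object):
        # list_bucket(bucket) -> list of key dicts; put_object(bucket, key) writes a key.
        self._list_bucket = list_bucket
        self._put_object = put_object
        self.dirs = {}   # bucket -> cached listing

    def ls(self, bucket, refresh=False):
        # Lazily (re)populate: only hit the backend when asked to, or when the
        # cache has been invalidated since the last listing.
        if refresh or bucket not in self.dirs:
            self.dirs[bucket] = self._list_bucket(bucket)
        return self.dirs[bucket]

    def invalidate_cache(self, bucket=None):
        # Same shape as the new S3FileSystem.invalidate_cache in the diff below.
        if bucket is None:
            self.dirs.clear()
        elif bucket in self.dirs:
            del self.dirs[bucket]

    def touch(self, bucket, key):
        self._put_object(bucket, key)    # write through to the backend...
        self.invalidate_cache(bucket)    # ...then drop the now-stale listing

A burst of N small writes now costs N cheap dict operations plus a single listing on the next read, rather than N full bucket listings, which is why the refresh_off()/refresh_on() switches and the no_refresh context manager can be deleted.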
123 changes: 59 additions & 64 deletions s3fs/core.py
@@ -12,7 +12,7 @@
from botocore.exceptions import ClientError, ParamValidationError
from botocore.client import Config

from .utils import read_block
from .utils import read_block, raises

logger = logging.getLogger(__name__)

@@ -185,16 +185,6 @@ def get_delegated_s3pars(self, exp=3600):
return {'key': cred['AccessKeyId'], 'secret': cred['SecretAccessKey'],
'token': cred['SessionToken'], 'anon': False}

def refresh_off(self):
""" Block auto-refresh when writing.
Use in conjunction with `refresh_on()` when writing many files to S3.
"""
self.no_refresh = True

def refresh_on(self):
self.no_refresh = False

def __getstate__(self):
d = self.__dict__.copy()
del d['s3']
@@ -271,12 +261,12 @@ def _ls(self, path, refresh=False):
files = self.dirs[bucket]
return files

def ls(self, path, detail=False):
def ls(self, path, detail=False, refresh=False):
""" List single "directory" with or without details """
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
files = self._ls(path)
files = self._ls(path, refresh=refresh)
if path:
pattern = re.compile(path + '/[^/]*.$')
files = [f for f in files if pattern.match(f['Key']) is not None]
@@ -290,7 +280,7 @@ def ls(self, path, detail=False):
else:
return [f['Key'] for f in files]

def info(self, path):
def info(self, path, refresh=False):
""" Detail on the specific file pointed to by path.
NB: path has trailing '/' stripped to work as `ls` does, so key
@@ -299,18 +289,19 @@ def info(self, path):
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
files = self._ls(path)
files = self._ls(path, refresh=refresh)
files = [f for f in files if f['Key'].rstrip('/') == path]
if len(files) == 1:
return files[0]
else:
raise IOError("File not found: %s" % path)

def walk(self, path):
def walk(self, path, refresh=False):
""" Return all entries below path """
if path.startswith('s3://'):
path = path[len('s3://'):]
return [f['Key'] for f in self._ls(path) if f['Key'].rstrip('/'
filenames = self._ls(path, refresh=refresh)
return [f['Key'] for f in filenames if f['Key'].rstrip('/'
).startswith(path.rstrip('/') + '/')]

def glob(self, path):
@@ -364,7 +355,8 @@ def exists(self, path):
if split_path(path)[1]:
return bool(self.ls(path))
else:
return path in self.ls('')
return (path in self.ls('') and
not raises(FileNotFoundError, lambda: self.ls(path)))

def cat(self, path):
""" Returns contents of file """
@@ -443,8 +435,7 @@ def merge(self, path, filelist):
part_info = {'Parts': parts}
self.s3.complete_multipart_upload(Bucket=bucket, Key=key,
UploadId=mpu['UploadId'], MultipartUpload=part_info)
self._ls(bucket, refresh=True)

self.invalidate_cache(bucket)

def copy(self, path1, path2):
""" Copy file between locations on S3 """
@@ -455,7 +446,7 @@ def copy(self, path1, path2):
CopySource='/'.join([buc1, key1]))
except (ClientError, ParamValidationError):
raise IOError('Copy failed', (path1, path2))
self._ls(path2, refresh=True)
self.invalidate_cache(buc2)

def rm(self, path, recursive=False):
"""
@@ -481,18 +472,24 @@ def rm(self, path, recursive=False):
self.s3.delete_object(Bucket=bucket, Key=key)
except ClientError:
raise IOError('Delete key failed', (bucket, key))
self._ls(path, refresh=True)
self.invalidate_cache(bucket)
else:
if not self.s3.list_objects(Bucket=bucket).get('Contents'):
try:
self.s3.delete_bucket(Bucket=bucket)
except ClientError:
raise IOError('Delete bucket failed', bucket)
self.dirs.pop(bucket, None)
self._ls('', refresh=True)
self.invalidate_cache(bucket)
else:
raise IOError('Not empty', path)

def invalidate_cache(self, bucket=None):
if bucket is None:
self.dirs.clear()
elif bucket in self.dirs:
del self.dirs[bucket]

def touch(self, path):
"""
Create empty key
@@ -502,11 +499,11 @@ def touch(self, path):
bucket, key = split_path(path)
if key:
self.s3.put_object(Bucket=bucket, Key=key)
self._ls(bucket, refresh=True)
self.invalidate_cache(bucket)
else:
try:
self.s3.create_bucket(Bucket=bucket)
self._ls("", refresh=True)
self.invalidate_cache('')
except (ClientError, ParamValidationError):
raise IOError('Bucket create failed', path)

@@ -557,26 +554,6 @@ def read_block(self, fn, offset, length, delimiter=None):
return bytes


@contextmanager
def no_refresh(s3fs):
""" Wrap an s3fs with this to temporarily block freshing filecache on writes.
Use this if writing many small files to a bucket.
The filelist will only be refreshed by the next writing action, or
explicit call to `s3fs._ls(bucket, refresh=True)`.
Usage
-----
>>> with no_refresh(s3fs) as fs: # doctest: +SKIP
[fs.touch('mybucket/file%i'%i) for i in range(1500)] # doctest: +SKIP
"""
s3fs.refresh_off()
try:
yield s3fs
finally:
s3fs.refresh_on()


class S3File(object):
"""
Open S3 key as a file. Data is only loaded and cached on demand.
@@ -618,28 +595,30 @@ def __init__(self, s3, path, mode='rb', block_size=5 * 2 ** 20):
self.end = None
self.closed = False
self.trim = True
self.mpu = None
if mode in {'wb', 'ab'}:
self.buffer = io.BytesIO()
self.parts = []
self.size = 0
if block_size < 5 * 2 ** 20:
raise ValueError('Block size must be >=5MB')
try:
self.mpu = s3.s3.create_multipart_upload(Bucket=bucket, Key=key)
except (ClientError, ParamValidationError):
raise IOError('Open for write failed', path)
self.forced = False
if mode == 'ab' and s3.exists(path):
self.size = s3.info(path)['Size']
if self.size < 5*2**20:
# existing file too small for multi-upload: download
self.write(s3.cat(path))
else:
try:
self.mpu = s3.s3.create_multipart_upload(Bucket=bucket, Key=key)
except (ClientError, ParamValidationError):
raise IOError('Open for write failed', path)
self.loc = self.size
out = self.s3.s3.upload_part_copy(Bucket=self.bucket, Key=self.key,
PartNumber=1, UploadId=self.mpu['UploadId'],
CopySource=path)
self.parts.append({'PartNumber': 1, 'ETag': out['CopyPartResult']['ETag']})
self.parts.append({'PartNumber': 1,
'ETag': out['CopyPartResult']['ETag']})
else:
try:
self.size = self.info()['Size']
@@ -782,15 +761,17 @@ def flush(self, force=False, retries=10):
"""
Write buffered data to S3.
Uploads the current buffer, if it is larger than the block-size.
Due to S3 multi-upload policy, you can only safely force flush to S3
when you are finished writing. It is unsafe to call this function
repeatedly.
Parameters
----------
force : bool (True)
Whether to write even if the buffer is less than the blocksize. If
less than the S3 part minimum (5MB), must be last block.
force : bool
When closing, write the last block even if it is smaller than
blocks are allowed to be.
"""
if self.mode in {'wb', 'ab'} and not self.closed:
if self.buffer.tell() < self.blocksize and not force:
@@ -799,27 +780,36 @@ def flush(self, force=False, retries=10):
if self.buffer.tell() == 0:
# no data in the buffer to write
return
if force and self.forced and self.buffer.tell() < 5 * 2 ** 20:
raise IOError('Under-sized block already written')
if force and self.buffer.tell() < 5 * 2 ** 20:
if force and self.forced:
raise ValueError("Force flush cannot be called more than once")
if force:
self.forced = True

self.buffer.seek(0)
part = len(self.parts) + 1
i = 0

try:
self.mpu = self.mpu or self.s3.s3.create_multipart_upload(
Bucket=self.bucket, Key=self.key)
except (ClientError, ParamValidationError):
raise IOError('Initiating write failed: %s' % self.path)

while True:
try:
out = self.s3.s3.upload_part(Bucket=self.bucket, Key=self.key,
out = self.s3.s3.upload_part(Bucket=self.bucket,
PartNumber=part, UploadId=self.mpu['UploadId'],
Body=self.buffer.read())
Body=self.buffer.read(), Key=self.key)
break
except S3_RETRYABLE_ERRORS:
if i < retries:
logger.debug('Exception %e on S3 upload, retrying',
logger.debug('Exception %e on S3 write, retrying',
exc_info=True)
i += 1
continue
else:
raise IOError('Write failed after %i retries'%retries, self)
raise IOError('Write failed after %i retries' % retries,
self)
except:
raise IOError('Write failed', self)
self.parts.append({'PartNumber': part, 'ETag': out['ETag']})
@@ -833,20 +823,25 @@ def close(self):
"""
if self.closed:
return
self.flush(True)
self.cache = None
self.closed = True
if self.mode in {'wb', 'ab'}:
if self.parts:
self.flush(force=True)
part_info = {'Parts': self.parts}
self.s3.s3.complete_multipart_upload(Bucket=self.bucket,
Key=self.key,
UploadId=self.mpu[
'UploadId'],
MultipartUpload=part_info)
else:
self.s3.s3.put_object(Bucket=self.bucket, Key=self.key)
self.s3._ls(self.bucket, refresh=True)
self.buffer.seek(0)
try:
self.s3.s3.put_object(Bucket=self.bucket, Key=self.key,
Body=self.buffer.read())
except (ClientError, ParamValidationError):
raise IOError('Write failed: %s' % self.path)
self.s3.invalidate_cache(self.bucket)
self.closed = True

def readable(self):
"""Return whether the S3File was opened for reading"""
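Note on the exists() hunk above: it now relies on a `raises` helper imported from s3fs.utils at the top of core.py. The s3fs/utils.py hunk (the third changed file) is not shown on this page, so the following is only a plausible sketch of the helper's behaviour, not the committed code:

def raises(exception, func):
    # Return True if calling func() raises the given exception type, else False;
    # any other exception propagates.  exists() uses it as:
    #     not raises(FileNotFoundError, lambda: self.ls(path))
    try:
        func()
    except exception:
        return True
    return False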
35 changes: 15 additions & 20 deletions s3fs/tests/test_s3fs.py
@@ -2,7 +2,7 @@
import io
import pytest
from itertools import chain
from s3fs.core import S3FileSystem, no_refresh
from s3fs.core import S3FileSystem
from s3fs.utils import seek_delimiter, ignoring, tmpfile
import moto

@@ -203,24 +203,12 @@ def test_s3_ls(s3):


def test_s3_big_ls(s3):
with no_refresh(s3) as s3:
for x in range(1200):
s3.touch(test_bucket_name+'/thousand/%i.part'%x)
s3._ls(test_bucket_name, refresh=True)
for x in range(1200):
s3.touch(test_bucket_name+'/thousand/%i.part'%x)
assert len(s3.walk(test_bucket_name)) > 1200
s3.rm(test_bucket_name+'/thousand/', recursive=True)


def test_no_refresh(s3):
set1 = s3.walk(test_bucket_name)
s3.refresh_off()
s3.touch(test_bucket_name+'/another')
assert set1 == s3.walk(test_bucket_name)
s3.refresh_on()
s3.touch(test_bucket_name+'/yet_another')
assert len(set1) < len(s3.walk(test_bucket_name))


def test_s3_ls_detail(s3):
L = s3.ls(test_bucket_name+'/nested', detail=True)
assert all(isinstance(item, dict) for item in L)
@@ -413,6 +401,8 @@ def test_write_small(s3):
with s3.open(test_bucket_name+'/test', 'wb') as f:
f.write(b'hello')
assert s3.cat(test_bucket_name+'/test') == b'hello'
s3.open(test_bucket_name+'/test', 'wb').close()
assert s3.info(test_bucket_name+'/test')['Size'] == 0

def test_write_fails(s3):
with pytest.raises(NotImplementedError):
@@ -432,14 +422,17 @@ def test_write_fails(s3):
with pytest.raises(ValueError):
f.write(b'hello')
with pytest.raises((OSError, IOError)):
s3.open('nonexistentbucket/temp', 'wb')
s3.open('nonexistentbucket/temp', 'wb').close()

def test_write_blocks(s3):
with s3.open(test_bucket_name+'/temp', 'wb') as f:
f.write(b'a' * 2*2**20)
assert f.buffer.tell() == 2*2**20
assert not(f.parts)
f.write(b'a' * 2*2**20)
f.write(b'a' * 2*2**20)
assert f.mpu
assert f.parts
assert s3.info(test_bucket_name+'/temp')['Size'] == 6*2**20
with s3.open(test_bucket_name+'/temp', 'wb', block_size=10*2**20) as f:
f.write(b'a' * 15*2**20)
@@ -580,21 +573,23 @@ def test_append(s3):
assert s3.cat(test_bucket_name+'/nested/file1') == data
with s3.open(test_bucket_name+'/nested/file1', 'ab') as f:
f.write(b'extra') # append, write, small file
assert s3.cat(test_bucket_name+'/nested/file1') == data+b'extra'
assert s3.cat(test_bucket_name+'/nested/file1') == data+b'extra'

with s3.open(a, 'wb') as f:
f.write(b'a' * 10*2**20)
with s3.open(a, 'ab') as f:
pass # append, no write, big file
pass # append, no write, big file
assert s3.cat(a) == b'a' * 10*2**20

with s3.open(a, 'ab') as f:
f.write(b'extra') # append, small write, big file
assert f.parts
assert f.tell() == 10*2**20
f.write(b'extra') # append, small write, big file
assert s3.cat(a) == b'a' * 10*2**20 + b'extra'

with s3.open(a, 'ab') as f:
assert f.tell() == 10*2**20 + 5
f.write(b'b' * 10*2**20) # append, big write, big file
f.write(b'b' * 10*2**20) # append, big write, big file
assert f.tell() == 20*2**20 + 5
assert s3.cat(a) == b'a' * 10*2**20 + b'extra' + b'b' *10*2**20

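The test changes above also pin down a consequence of making the multipart upload lazy (self.mpu = None at open, created on the first real flush): a file opened for writing and closed before reaching the 5MB part minimum is written with a single put_object, and closing with nothing written now produces a zero-byte key. A rough usage sketch of the paths the tests exercise, assuming a writable bucket named 'mybucket' (not part of this diff):

from s3fs import S3FileSystem

fs = S3FileSystem()

with fs.open('mybucket/small', 'wb') as f:
    f.write(b'hello')                      # under 5MB: one put_object at close

fs.open('mybucket/empty', 'wb').close()    # nothing written: zero-byte object, no multipart upload

with fs.open('mybucket/big', 'wb') as f:
    f.write(b'a' * 6 * 2**20)              # over 5MB: multipart upload created on first flush

fs.ls('mybucket', refresh=True)            # the new refresh= keyword forces a re-listing;
                                           # the writes above already invalidated the cache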
