Skip to content

Commit

Permalink
Add option to clear vsicurl cache on entering Env
Browse files Browse the repository at this point in the history
Resolves #1078
  • Loading branch information
Sean Gillies committed Nov 25, 2020
1 parent ea8220b commit 86f64d9
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 4 deletions.
7 changes: 6 additions & 1 deletion rasterio/_env.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import threading
from rasterio._base cimport _safe_osr_release
from rasterio._err import CPLE_BaseError
from rasterio._err cimport exc_wrap_ogrerr, exc_wrap_int
from rasterio._shim cimport set_proj_search_path
from rasterio._shim cimport set_proj_search_path, vsi_curl_clear_cache

from libc.stdio cimport stderr

Expand Down Expand Up @@ -417,3 +417,8 @@ cdef class GDALEnv(ConfigEnv):
# Python-callable wrapper: forwards `path` unchanged to set_proj_search_path(),
# which is cimported from rasterio._shim at the top of this module.
def set_proj_data_search_path(path):
"""Set PROJ data search path"""
set_proj_search_path(path)


# Python-callable wrapper around vsi_curl_clear_cache(), which is cimported
# from rasterio._shim at the top of this module. Takes no arguments and
# returns None.
def vsicurl_clear_cache():
"""Clear GDAL's vsicurl cache"""
vsi_curl_clear_cache()
1 change: 1 addition & 0 deletions rasterio/_shim.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ cdef int io_multi_mask(GDALDatasetH hds, int mode, double xoff, double yoff, dou
cdef const char* osr_get_name(OGRSpatialReferenceH hSrs)
cdef void osr_set_traditional_axis_mapping_strategy(OGRSpatialReferenceH hSrs)
cdef void set_proj_search_path(object path)
cdef void vsi_curl_clear_cache()
4 changes: 4 additions & 0 deletions rasterio/_shim1.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,7 @@ cdef void osr_set_traditional_axis_mapping_strategy(OGRSpatialReferenceH hSrs):

# Set the PROJ_LIB environment variable to `path` for the current process
# (via os.environ). Nothing is validated or returned.
cdef void set_proj_search_path(object path):
os.environ["PROJ_LIB"] = path


# No-op in this shim: nothing is called here, so GDAL's vsicurl cache is left
# untouched.
# NOTE(review): presumably the GDAL version this shim targets does not offer
# VSICurlClearCache() -- confirm.
cdef void vsi_curl_clear_cache():
pass
4 changes: 4 additions & 0 deletions rasterio/_shim20.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,7 @@ cdef void osr_set_traditional_axis_mapping_strategy(OGRSpatialReferenceH hSrs):

# Set the PROJ_LIB environment variable to `path` for the current process
# (via os.environ). Nothing is validated or returned.
cdef void set_proj_search_path(object path):
os.environ["PROJ_LIB"] = path


# No-op in this shim: nothing is called here, so GDAL's vsicurl cache is left
# untouched.
# NOTE(review): presumably the GDAL version this shim targets does not offer
# VSICurlClearCache() -- confirm.
cdef void vsi_curl_clear_cache():
pass
5 changes: 5 additions & 0 deletions rasterio/_shim21.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,14 @@ cdef int delete_nodata_value(GDALRasterBandH hBand) except 3:
# Stub for this shim: always returns an empty string; `hSrs` is not inspected.
cdef const char* osr_get_name(OGRSpatialReferenceH hSrs):
return ''


# No-op in this shim: `hSrs` is accepted but left untouched.
cdef void osr_set_traditional_axis_mapping_strategy(OGRSpatialReferenceH hSrs):
pass


# Set the PROJ_LIB environment variable to `path` for the current process
# (via os.environ). Nothing is validated or returned.
cdef void set_proj_search_path(object path):
os.environ["PROJ_LIB"] = path


# No-op in this shim: nothing is called here, so GDAL's vsicurl cache is left
# untouched.
# NOTE(review): presumably the GDAL version this shim targets does not offer
# VSICurlClearCache() -- confirm.
cdef void vsi_curl_clear_cache():
pass
9 changes: 9 additions & 0 deletions rasterio/_shim30.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ cdef extern from "ogr_srs_api.h" nogil:
void OSRSetPROJSearchPaths(const char *const *papszPaths)


cdef extern from "cpl_vsi.h" nogil:

void VSICurlClearCache()


from rasterio._err cimport exc_wrap_pointer


Expand Down Expand Up @@ -100,3 +105,7 @@ cdef void set_proj_search_path(object path):
path_c = path_b
paths = CSLAddString(paths, path_c)
OSRSetPROJSearchPaths(paths)


# Clear GDAL's vsicurl cache by calling VSICurlClearCache(), which this module
# declares from "cpl_vsi.h". Takes no arguments and returns nothing.
cdef void vsi_curl_clear_cache():
VSICurlClearCache()
10 changes: 8 additions & 2 deletions rasterio/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from rasterio._env import (
GDALEnv, get_gdal_config, set_gdal_config,
GDALDataFinder, PROJDataFinder, set_proj_data_search_path)
GDALDataFinder, PROJDataFinder, set_proj_data_search_path,
vsicurl_clear_cache)
from rasterio.compat import string_types, getargspec
from rasterio.errors import (
EnvError, GDALVersionError, RasterioDeprecationWarning)
Expand Down Expand Up @@ -103,7 +104,7 @@ def default_options(cls):
}

def __init__(self, session=None, aws_unsigned=False, profile_name=None,
session_class=Session.aws_or_dummy, **options):
session_class=Session.aws_or_dummy, clear_vsicurl_cache=False, **options):
"""Create a new GDAL/AWS environment.
Note: this class is a context manager. GDAL isn't configured
Expand All @@ -119,6 +120,8 @@ def __init__(self, session=None, aws_unsigned=False, profile_name=None,
A shared credentials profile name, as per boto3.
session_class : Session, optional
A sub-class of Session.
clear_vsicurl_cache : bool, optional
If True, GDAL's vsicurl cache will be cleared on enter.
**options : optional
A mapping of GDAL configuration options, e.g.,
`CPL_DEBUG=True, CHECK_WITH_INVERT_PROJ=False`.
Expand Down Expand Up @@ -199,6 +202,7 @@ def __init__(self, session=None, aws_unsigned=False, profile_name=None,
else:
self.session = DummySession()

self._clear_vsicurl_cache = clear_vsicurl_cache
self.options = options.copy()
self.context_options = {}

Expand Down Expand Up @@ -275,6 +279,8 @@ def __enter__(self):
self.context_options = getenv()
setenv(**self.options)

if self._clear_vsicurl_cache:
vsicurl_clear_cache()
self.credentialize()

log.debug("Entered env context: %r", self)
Expand Down
1 change: 1 addition & 0 deletions rasterio/gdal.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ cdef extern from "cpl_vsi.h" nogil:
size_t VSIFWriteL(void *buffer, size_t nSize, size_t nCount, VSILFILE *fp)
int VSIStatL(const char *pszFilename, VSIStatBufL *psStatBuf)


cdef extern from "ogr_srs_api.h" nogil:

ctypedef int OGRErr
Expand Down
114 changes: 114 additions & 0 deletions tests/rangehttpserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#!/usr/bin/python
'''
Use this in the same way as Python's SimpleHTTPServer:
python -m RangeHTTPServer [port]
The only difference from SimpleHTTPServer is that RangeHTTPServer supports
'Range:' headers to load portions of files. This is helpful for doing local web
development with genomic data files, which tend to be too large to load into the
browser all at once.
'''

import os
import re

try:
# Python3
from http.server import SimpleHTTPRequestHandler

except ImportError:
# Python 2
from SimpleHTTPServer import SimpleHTTPRequestHandler


def copy_byte_range(infile, outfile, start=None, stop=None, bufsize=16*1024):
    '''Like shutil.copyfileobj, but only copy a range of the streams.

    Both start and stop are inclusive byte offsets into ``infile``.

    Parameters
    ----------
    infile : file-like
        Readable, seekable binary stream to copy from.
    outfile : file-like
        Writable binary stream to copy to.
    start : int, optional
        Offset of the first byte to copy. If None, copying begins at the
        current position of ``infile``.
    stop : int, optional
        Offset of the last byte to copy. If None, copying continues to the
        end of ``infile``.
    bufsize : int, optional
        Maximum number of bytes read per iteration.
    '''
    if start is not None:
        infile.seek(start)
    while True:
        if stop is not None:
            # Compare against None, not truthiness: a stop offset of 0 is a
            # valid request for exactly the first byte.
            to_read = min(bufsize, stop + 1 - infile.tell())
        else:
            to_read = bufsize
        # A non-positive count means the range is exhausted (or empty).
        # Passing a negative size to read() would return the rest of the file.
        if to_read <= 0:
            break
        buf = infile.read(to_read)
        if not buf:
            break
        outfile.write(buf)


# Matches 'bytes=<first>-' and 'bytes=<first>-<last>'; the last offset is optional.
BYTE_RANGE_RE = re.compile(r'bytes=(\d+)-(\d+)?$')


def parse_byte_range(byte_range):
    '''Returns the two numbers in 'bytes=123-456' or throws ValueError.

    The last number or both numbers may be None: a blank header value gives
    (None, None) and an open-ended range such as 'bytes=123-' gives
    (123, None).

    Raises
    ------
    ValueError
        If the value does not match the 'bytes=first-[last]' form, or if
        last is smaller than first.
    '''
    if byte_range.strip() == '':
        return None, None

    m = BYTE_RANGE_RE.match(byte_range)
    if not m:
        raise ValueError('Invalid byte range %s' % byte_range)

    first, last = [None if x is None else int(x) for x in m.groups()]
    # Test against None rather than truthiness so that a last offset of 0
    # (e.g. 'bytes=5-0') is still validated against first.
    if last is not None and last < first:
        raise ValueError('Invalid byte range %s' % byte_range)
    return first, last


class RangeRequestHandler(SimpleHTTPRequestHandler):
    """Adds support for HTTP 'Range' requests to SimpleHTTPRequestHandler

    The approach is to:
    - Override send_head to look for 'Range' and respond appropriately.
    - Override copyfile to only transmit a range when requested.

    ``self.range`` is set by send_head() to None for ordinary requests, or
    to the (first, last) tuple parsed from the 'Range' header.
    """

    def send_head(self):
        """Send the response headers and return the open file, or None.

        Without a usable 'Range' header this defers to
        SimpleHTTPRequestHandler.send_head. Otherwise it answers 400 for a
        malformed range, 404 when the file cannot be opened, 416 when the
        range starts past the end of the file, and 206 with
        'Content-Range' / 'Content-Length' headers on success.
        """
        if 'Range' not in self.headers:
            self.range = None
            return SimpleHTTPRequestHandler.send_head(self)
        try:
            self.range = parse_byte_range(self.headers['Range'])
        except ValueError:
            self.send_error(400, 'Invalid byte range')
            return None
        first, last = self.range
        if first is None:
            # A blank 'Range' header parses to (None, None). Treat it as an
            # ordinary request instead of failing on the comparison below.
            self.range = None
            return SimpleHTTPRequestHandler.send_head(self)

        # Mirroring SimpleHTTPServer.py here
        path = self.translate_path(self.path)
        ctype = self.guess_type(path)
        try:
            f = open(path, 'rb')
        except IOError:
            self.send_error(404, 'File not found')
            return None

        try:
            fs = os.fstat(f.fileno())
            file_len = fs.st_size
            if first >= file_len:
                # Nothing will be streamed, so release the file before
                # reporting the error.
                f.close()
                self.send_error(416, 'Requested Range Not Satisfiable')
                return None

            self.send_response(206)
            self.send_header('Content-type', ctype)
            self.send_header('Accept-Ranges', 'bytes')

            # Clamp an open-ended or overlong range to the last byte.
            if last is None or last >= file_len:
                last = file_len - 1
            response_length = last - first + 1

            self.send_header('Content-Range',
                             'bytes %s-%s/%s' % (first, last, file_len))
            self.send_header('Content-Length', str(response_length))
            self.send_header('Last-Modified',
                             self.date_time_string(fs.st_mtime))
            self.end_headers()
            return f
        except Exception:
            # The caller only closes the file it receives; close it here if
            # sending the headers failed.
            f.close()
            raise

    def copyfile(self, source, outputfile):
        """Copy the whole of source, or only the requested byte range."""
        if not self.range:
            return SimpleHTTPRequestHandler.copyfile(self, source, outputfile)

        # SimpleHTTPRequestHandler uses shutil.copyfileobj, which doesn't let
        # you stop the copying before the end of the file.
        start, stop = self.range  # set in send_head()
        copy_byte_range(source, outputfile, start, stop)
46 changes: 45 additions & 1 deletion tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Collected here to make them easier to skip/xfail.

import os
import shutil
import sys

import boto3
Expand All @@ -21,7 +22,7 @@
from rasterio.rio.main import main_group
from rasterio.session import AWSSession, DummySession, OSSSession, SwiftSession

from .conftest import requires_gdal21
from .conftest import requires_gdal21, requires_gdal3


# Custom markers.
Expand Down Expand Up @@ -894,3 +895,46 @@ def test_open_file_expired_aws_credentials(monkeypatch, caplog, path_rgb_byte_ti
with rasterio.env.Env():
with rasterio.open(path_rgb_byte_tif) as dataset:
assert not dataset.closed


@pytest.fixture
def http_server(tmpdir):
    """Serve ``tmpdir`` over HTTP on port 8000 for the duration of a test.

    The server uses rangehttpserver.RangeRequestHandler and runs in a
    separate process. That process is terminated and the listening socket
    is closed during teardown.
    """
    import functools
    import multiprocessing
    import http.server
    from . import rangehttpserver
    PORT = 8000
    Handler = functools.partial(rangehttpserver.RangeRequestHandler, directory=str(tmpdir))
    httpd = http.server.HTTPServer(("", PORT), Handler)
    try:
        p = multiprocessing.Process(target=httpd.serve_forever)
        p.start()
        try:
            yield
        finally:
            p.terminate()
            p.join()
    finally:
        # Close this process's listening socket explicitly so the port is
        # free for the next test that uses this fixture.
        httpd.server_close()


@pytest.mark.xfail(reason="GDAL has cached the first failed request")
def test_vsi_curl_failure_cache(tmpdir, http_server):
    """A failed first request is cached, so the retry is expected to fail."""
    url = "/vsicurl/http://localhost:8000/red.tif"

    # The file has not been copied into the served directory yet.
    with pytest.raises(RasterioIOError):
        rasterio.open(url)

    # Make the file available, then request the same URL again.
    shutil.copy("tests/data/red.tif", str(tmpdir))

    with rasterio.open(url) as src:
        band_count = src.count
        assert band_count == 3
        first_band = src.read(1)
        assert (first_band == 204).all()


@requires_gdal3(reason="Cache clearing requires GDAL 3+")
def test_vsi_curl_cache_clear(tmpdir, http_server):
    """Entering Env with clear_vsicurl_cache=True wipes out a previous failure."""
    url = "/vsicurl/http://localhost:8000/red.tif"

    # The file has not been copied into the served directory yet.
    with pytest.raises(RasterioIOError):
        rasterio.open(url)

    # Make the file available, then retry inside an Env that clears the cache.
    shutil.copy("tests/data/red.tif", str(tmpdir))

    with rasterio.Env(clear_vsicurl_cache=True):
        with rasterio.open(url) as src:
            band_count = src.count
            assert band_count == 3
            first_band = src.read(1)
            assert (first_band == 204).all()

0 comments on commit 86f64d9

Please sign in to comment.