Skip to content

Commit

Permalink
Tracking cache requests (#1566)
Browse files Browse the repository at this point in the history
  • Loading branch information
betolink committed Apr 11, 2024
1 parent 2dd9355 commit 05e7d80
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 28 deletions.
97 changes: 83 additions & 14 deletions fsspec/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,13 @@ class BaseCache:

def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
self.blocksize = blocksize
self.nblocks = 0
self.fetcher = fetcher
self.size = size
self.hit_count = 0
self.miss_count = 0
# the bytes that we actually requested
self.total_requested_bytes = 0

def _fetch(self, start: int | None, stop: int | None) -> bytes:
if start is None:
Expand All @@ -68,6 +73,36 @@ def _fetch(self, start: int | None, stop: int | None) -> bytes:
return b""
return self.fetcher(start, stop)

def _reset_stats(self) -> None:
"""Reset hit and miss counts for a more ganular report e.g. by file."""
self.hit_count = 0
self.miss_count = 0
self.total_requested_bytes = 0

def _log_stats(self) -> str:
"""Return a formatted string of the cache statistics."""
if self.hit_count == 0 and self.miss_count == 0:
# a cache that does nothing, this is for logs only
return ""
return " , %s: %d hits, %d misses, %d total requested bytes" % (
self.name,
self.hit_count,
self.miss_count,
self.total_requested_bytes,
)

def __repr__(self) -> str:
# TODO: use rich for better formatting
return f"""
<{self.__class__.__name__}:
block size : {self.blocksize}
block count : {self.nblocks}
file size : {self.size}
cache hits : {self.hit_count}
cache misses: {self.miss_count}
total requested bytes: {self.total_requested_bytes}>
"""


class MMapCache(BaseCache):
"""memory-mapped sparse file cache
Expand Down Expand Up @@ -126,13 +161,18 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
start_block = start // self.blocksize
end_block = end // self.blocksize
need = [i for i in range(start_block, end_block + 1) if i not in self.blocks]
hits = [i for i in range(start_block, end_block + 1) if i in self.blocks]
self.miss_count += len(need)
self.hit_count += len(hits)
while need:
# TODO: not a for loop so we can consolidate blocks later to
# make fewer fetch calls; this could be parallel
i = need.pop(0)

sstart = i * self.blocksize
send = min(sstart + self.blocksize, self.size)
logger.debug(f"MMap get block #{i} ({sstart}-{send}")
self.total_requested_bytes += send - sstart
logger.debug(f"MMap get block #{i} ({sstart}-{send})")
self.cache[sstart:send] = self.fetcher(sstart, send)
self.blocks.add(i)

Expand Down Expand Up @@ -176,16 +216,20 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
l = end - start
if start >= self.start and end <= self.end:
# cache hit
self.hit_count += 1
return self.cache[start - self.start : end - self.start]
elif self.start <= start < self.end:
# partial hit
self.miss_count += 1
part = self.cache[start - self.start :]
l -= len(part)
start = self.end
else:
# miss
self.miss_count += 1
part = b""
end = min(self.size, end + self.blocksize)
self.total_requested_bytes += end - start
self.cache = self.fetcher(start, end) # new block replaces old
self.start = start
self.end = self.start + len(self.cache)
Expand All @@ -202,24 +246,39 @@ class FirstChunkCache(BaseCache):
name = "first"

def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
if blocksize > size:
# this will buffer the whole thing
blocksize = size
super().__init__(blocksize, fetcher, size)
self.cache: bytes | None = None

def _fetch(self, start: int | None, end: int | None) -> bytes:
start = start or 0
end = end or self.size
if start > self.size:
logger.debug("FirstChunkCache: requested start > file size")
return b""

end = min(end, self.size)

if start < self.blocksize:
if self.cache is None:
self.miss_count += 1
if end > self.blocksize:
self.total_requested_bytes += end
data = self.fetcher(0, end)
self.cache = data[: self.blocksize]
return data[start:]
self.cache = self.fetcher(0, self.blocksize)
self.total_requested_bytes += self.blocksize
part = self.cache[start:end]
if end > self.blocksize:
self.total_requested_bytes += end - self.blocksize
part += self.fetcher(self.blocksize, end)
self.hit_count += 1
return part
else:
self.miss_count += 1
self.total_requested_bytes += end - start
return self.fetcher(start, end)


Expand Down Expand Up @@ -256,12 +315,6 @@ def __init__(
self.maxblocks = maxblocks
self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)

def __repr__(self) -> str:
return (
f"<BlockCache blocksize={self.blocksize}, "
f"size={self.size}, nblocks={self.nblocks}>"
)

def cache_info(self):
"""
The statistics on the block cache.
Expand Down Expand Up @@ -319,6 +372,8 @@ def _fetch_block(self, block_number: int) -> bytes:

start = block_number * self.blocksize
end = start + self.blocksize
self.total_requested_bytes += end - start
self.miss_count += 1
logger.info("BlockCache fetching block %d", block_number)
block_contents = super()._fetch(start, end)
return block_contents
Expand All @@ -339,6 +394,7 @@ def _read_cache(
start_pos = start % self.blocksize
end_pos = end % self.blocksize

self.hit_count += 1
if start_block_number == end_block_number:
block: bytes = self._fetch_block_cached(start_block_number)
return block[start_pos:end_pos]
Expand Down Expand Up @@ -404,6 +460,7 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
):
# cache hit: we have all the required data
offset = start - self.start
self.hit_count += 1
return self.cache[offset : offset + end - start]

if self.blocksize:
Expand All @@ -418,27 +475,34 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
self.end is None or end > self.end
):
# First read, or extending both before and after
self.total_requested_bytes += bend - start
self.miss_count += 1
self.cache = self.fetcher(start, bend)
self.start = start
else:
assert self.start is not None
assert self.end is not None
self.miss_count += 1

if start < self.start:
if self.end is None or self.end - end > self.blocksize:
self.total_requested_bytes += bend - start
self.cache = self.fetcher(start, bend)
self.start = start
else:
self.total_requested_bytes += self.start - start
new = self.fetcher(start, self.start)
self.start = start
self.cache = new + self.cache
elif self.end is not None and bend > self.end:
if self.end > self.size:
pass
elif end - self.end > self.blocksize:
self.total_requested_bytes += bend - start
self.cache = self.fetcher(start, bend)
self.start = start
else:
self.total_requested_bytes += bend - self.end
new = self.fetcher(self.end, bend)
self.cache = self.cache + new

Expand Down Expand Up @@ -470,10 +534,13 @@ def __init__(
) -> None:
super().__init__(blocksize, fetcher, size) # type: ignore[arg-type]
if data is None:
self.miss_count += 1
self.total_requested_bytes += self.size
data = self.fetcher(0, self.size)
self.data = data

def _fetch(self, start: int | None, stop: int | None) -> bytes:
self.hit_count += 1
return self.data[start:stop]


Expand Down Expand Up @@ -551,6 +618,7 @@ def _fetch(self, start: int | None, stop: int | None) -> bytes:
# are allowed to pad reads beyond the
# buffer with zero
out += b"\x00" * (stop - start - len(out))
self.hit_count += 1
return out
else:
# The request ends outside a known range,
Expand All @@ -572,6 +640,8 @@ def _fetch(self, start: int | None, stop: int | None) -> bytes:
f"IO/caching performance may be poor!"
)
logger.debug(f"KnownPartsOfAFile cache fetching {start}-{stop}")
self.total_requested_bytes += stop - start
self.miss_count += 1
return out + super()._fetch(start, stop)


Expand Down Expand Up @@ -676,12 +746,6 @@ def __init__(
self._fetch_future: Future[bytes] | None = None
self._fetch_future_lock = threading.Lock()

def __repr__(self) -> str:
return (
f"<BackgroundBlockCache blocksize={self.blocksize}, "
f"size={self.size}, nblocks={self.nblocks}>"
)

def cache_info(self) -> UpdatableLRU.CacheInfo:
"""
The statistics on the block cache.
Expand Down Expand Up @@ -799,6 +863,8 @@ def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes:
start = block_number * self.blocksize
end = start + self.blocksize
logger.info("BlockCache fetching block (%s) %d", log_info, block_number)
self.total_requested_bytes += end - start
self.miss_count += 1
block_contents = super()._fetch(start, end)
return block_contents

Expand All @@ -818,6 +884,9 @@ def _read_cache(
start_pos = start % self.blocksize
end_pos = end % self.blocksize

# kind of pointless to count this as a hit, but it is
self.hit_count += 1

if start_block_number == end_block_number:
block = self._fetch_block_cached(start_block_number)
return block[start_pos:end_pos]
Expand Down
9 changes: 8 additions & 1 deletion fsspec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,11 +1841,18 @@ def read(self, length=-1):
length = self.size - self.loc
if self.closed:
raise ValueError("I/O operation on closed file.")
logger.debug("%s read: %i - %i", self, self.loc, self.loc + length)
if length == 0:
# don't even bother calling fetch
return b""
out = self.cache._fetch(self.loc, self.loc + length)

logger.debug(
"%s read: %i - %i %s",
self,
self.loc,
self.loc + length,
self.cache._log_stats(),
)
self.loc += len(out)
return out

Expand Down
Loading

0 comments on commit 05e7d80

Please sign in to comment.