Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-26343: Use less expansive definition of extension #354

Merged
merged 4 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 17 additions & 4 deletions python/lsst/daf/butler/core/_butlerUri.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,25 @@ def getExtension(self) -> str:
-------
ext : `str`
The file extension (including the ``.``). Can be empty string
if there is no file extension. Will return all file extensions
as a single extension such that ``file.fits.gz`` will return
a value of ``.fits.gz``.
if there is no file extension. Usually returns only the last
file extension unless there is a special extension modifier
indicating file compression, in which case the combined
extension (e.g. ``.fits.gz``) will be returned.
"""
special = {".gz", ".bz2", ".xz", ".fz"}

extensions = self._pathLib(self.path).suffixes
return "".join(extensions)

if not extensions:
return ""

ext = extensions.pop()

# Multiple extensions, decide whether to include the final two
if extensions and ext in special:
ext = f"{extensions[-1]}{ext}"

return ext

def join(self, path: str) -> ButlerURI:
"""Create a new `ButlerURI` with additional path components including
Expand Down
21 changes: 16 additions & 5 deletions python/lsst/daf/butler/core/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ def validateExtension(cls, location: Location) -> None:
writing files. If ``extension`` is `None` only the set of supported
extensions will be examined.
"""
ext = location.getExtension()
supported = set(cls.supportedExtensions)

try:
Expand All @@ -320,12 +319,24 @@ def validateExtension(cls, location: Location) -> None:
except AttributeError:
raise NotImplementedError("No file extension registered with this formatter") from None

if default is not None:
# If extension is implemented as an instance property it won't return
# a string when called as a class property. Assume that
# the supported extensions class property is complete.
if default is not None and isinstance(default, str):
supported.add(default)

if ext in supported:
return
raise ValueError(f"Extension '{ext}' on '{location}' is not supported by Formatter '{cls.__name__}'")
# Get the file name from the uri
file = location.uri.basename()

# Check that this file name ends with one of the supported extensions.
# This is less prone to confusion than asking the location for
# its extension and then doing a set comparison
for ext in supported:
if file.endswith(ext):
return

raise ValueError(f"Extension '{location.getExtension()}' on '{location}' "
f"is not supported by Formatter '{cls.__name__}' (supports: {supported})")

def predictPath(self) -> str:
"""Return the path that would be returned by write, without actually
Expand Down
13 changes: 12 additions & 1 deletion python/lsst/daf/butler/tests/testFormatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from __future__ import annotations

__all__ = ("FormatterTest", "DoNothingFormatter", "LenientYamlFormatter", "MetricsExampleFormatter")
__all__ = ("FormatterTest", "DoNothingFormatter", "LenientYamlFormatter", "MetricsExampleFormatter",
"MultipleExtensionsFormatter", "SingleExtensionFormatter")

from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -71,6 +72,16 @@ def validateWriteRecipes(recipes: Optional[Mapping[str, Any]]) -> Optional[Mappi
return recipes


class SingleExtensionFormatter(DoNothingFormatter):
    """Do-nothing formatter whose registered extension is ``.fits``."""

    # The one extension this formatter registers.
    extension = ".fits"


class MultipleExtensionsFormatter(SingleExtensionFormatter):
    """Formatter that registers several supported extensions on top of
    what it inherits from `SingleExtensionFormatter`.
    """

    # Extra extensions accepted by this formatter.
    supportedExtensions = frozenset((".fits.gz", ".fits.fz", ".fit"))


class LenientYamlFormatter(YamlFormatter):
"""A test formatter that allows any file extension but always reads and
writes YAML."""
Expand Down
24 changes: 23 additions & 1 deletion tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from lsst.daf.butler.tests import DatasetTestHelper
from lsst.daf.butler import (Formatter, FormatterFactory, StorageClass, DatasetType, Config,
FileDescriptor, Location, DimensionUniverse, DimensionGraph)
from lsst.daf.butler.tests.testFormatters import DoNothingFormatter
from lsst.daf.butler.tests.testFormatters import (DoNothingFormatter, MultipleExtensionsFormatter,
SingleExtensionFormatter)

TESTDIR = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -77,6 +78,27 @@ def testFormatter(self):
with self.assertRaises(NotImplementedError):
f.write("str")

def testExtensionValidation(self):
    """Check which file names each test formatter accepts.

    Every file name is validated against both `SingleExtensionFormatter`
    and `MultipleExtensionsFormatter`. Validation must either complete
    without raising or raise `ValueError`, as given by the table below.
    """
    # Columns: file name, accepted by SingleExtensionFormatter,
    # accepted by MultipleExtensionsFormatter.
    cases = (
        ("e.fits", True, True),
        ("e.fit", False, True),
        ("e.fits.fz", False, True),
        ("e.txt", False, False),
        ("e.1.4.fits", True, True),
        ("e.3.fit", False, True),
        ("e.1.4.fits.gz", False, True),
    )

    for name, single_ok, multi_ok in cases:
        location = Location("/a/b/c", name)
        expectations = (
            (SingleExtensionFormatter, single_ok),
            (MultipleExtensionsFormatter, multi_ok),
        )

        for formatter_cls, should_pass in expectations:
            if not should_pass:
                # Unsupported extension: validation has to reject it.
                with self.assertRaises(ValueError):
                    formatter_cls.validateExtension(location)
                continue

            # Supported extension: validation must not raise.
            formatter_cls.validateExtension(location)

def testRegistry(self):
"""Check that formatters can be stored in the registry.
"""
Expand Down
18 changes: 18 additions & 0 deletions tests/test_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,24 @@ def testButlerUriSerialization(self):
self.assertEqual(uri, uri2)
self.assertTrue(uri2.dirLike)

def testUriExtensions(self):
"""Test extension extraction."""

files = (("file.fits.gz", ".fits.gz"),
("file.fits", ".fits"),
("file.fits.xz", ".fits.xz"),
("file.fits.tar", ".tar"),
("file", ""),
("flat_i_sim_1.4_blah.fits.gz", ".fits.gz"),
("flat_i_sim_1.4_blah.txt", ".txt"),
("flat_i_sim_1.4_blah.fits.fz", ".fits.fz"),
("flat_i_sim_1.4_blah.fits.txt", ".txt"),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do support .fits and .fits.fz then I think that we should explicitly test that here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since `.fz` is not listed as a special extension at the moment, `getExtension()` will return `.fz` for `.fits.fz`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's fine! I would suggest putting those in the list so that it's clear what it will return (and that we know it's returning what we think it should return).


for file, expected in files:
uri = ButlerURI(f"a/b/{file}")
self.assertEqual(uri.getExtension(), expected)

def testFileLocation(self):
root = os.path.abspath(os.path.curdir)
factory = LocationFactory(root)
Expand Down