Skip to content

Commit

Permalink
Fix semantics of Record equality (lektor#1105)
Browse files Browse the repository at this point in the history
* A sources identity depends on alt as well as path

* Fix Attachment.iter_source_filenames

It should return the both the alt-specific and the fallback .lr
filenames (in addition to the attachment filename.)

* SourceObject subclasses should all implement .iter_source_filenames

The SourceObject.source_filename property now just returns the first
of the filenames returned by the iter_source_filenames method.

* Make the identity key (source, alt) explicit for VirtualSourceObjects

Both Records and VirtualSourceObjects are described by the identity
key (path, alt, pad).  That is, if

   obj2 = pad.get(obj.path, alt=obj.path)

then

   obj2 == obj

Records already had __eq__ and __hash__ methods which made this
explicit.  Here we add those methods for VirtualSourceObjects.

Note this allows us to convert Context.referenced_virtual_dependencies
from a dict to a set.  (The key was only being used as an identity key
for the VirtualSourceObjects.)
  • Loading branch information
dairiki committed Sep 11, 2023
1 parent 809b8d5 commit 63ec21d
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 68 deletions.
15 changes: 7 additions & 8 deletions lektor/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def get_asset(pad, filename, parent=None):
return None

try:
stat_obj = os.stat(os.path.join(parent.source_filename, filename))
stat_obj = os.stat(os.path.join(parent._source_filename, filename))
except OSError:
return None
if stat.S_ISDIR(stat_obj.st_mode):
Expand All @@ -24,24 +24,23 @@ def get_asset(pad, filename, parent=None):


class Asset(SourceObject):
# Source specific overrides. The source_filename to none removes
# the inherited descriptor.
source_classification = "asset"
source_filename = None

artifact_extension = ""

def __init__(self, pad, name, path=None, parent=None):
SourceObject.__init__(self, pad)
if parent is not None:
if path is None:
path = name
path = os.path.join(parent.source_filename, path)
self.source_filename = path
path = os.path.join(parent._source_filename, path)
self._source_filename = path

self.name = name
self.parent = parent

def iter_source_filenames(self):
yield self._source_filename

@property
def url_name(self):
name = self.name
Expand Down Expand Up @@ -99,7 +98,7 @@ class Directory(Asset):
@property
def children(self):
try:
files = os.listdir(self.source_filename)
files = os.listdir(self._source_filename)
except OSError:
return

Expand Down
4 changes: 2 additions & 2 deletions lektor/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def finish_update(self, ctx, exc_info=None):
if exc_info is None:
self._memorize_dependencies(
ctx.referenced_dependencies,
ctx.referenced_virtual_dependencies.values(),
ctx.referenced_virtual_dependencies,
)
self._commit()
return
Expand All @@ -1007,7 +1007,7 @@ def finish_update(self, ctx, exc_info=None):
# use a new database connection that immediately commits.
self._memorize_dependencies(
ctx.referenced_dependencies,
ctx.referenced_virtual_dependencies.values(),
ctx.referenced_virtual_dependencies,
for_failure=True,
)

Expand Down
5 changes: 2 additions & 3 deletions lektor/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, artifact=None, pad=None):

# Processing information
self.referenced_dependencies = set()
self.referenced_virtual_dependencies = {}
self.referenced_virtual_dependencies = set()
self.sub_artifacts = []

self.flow_block_render_stack = []
Expand Down Expand Up @@ -241,8 +241,7 @@ def record_dependency(self, filename, affects_url=None):

def record_virtual_dependency(self, virtual_source):
"""Records a dependency from processing."""
path = virtual_source.path
self.referenced_virtual_dependencies[path] = virtual_source
self.referenced_virtual_dependencies.add(virtual_source)
for coll in self._dependency_collectors:
coll(virtual_source)

Expand Down
66 changes: 22 additions & 44 deletions lektor/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from lektor.imagetools import make_image_thumbnail
from lektor.imagetools import read_exif
from lektor.imagetools import ThumbnailMode
from lektor.sourceobj import SourceObject
from lektor.sourceobj import DBSourceObject
from lektor.sourceobj import VirtualSourceObject
from lektor.utils import cleanup_path
from lektor.utils import cleanup_url_path
Expand Down Expand Up @@ -303,12 +303,12 @@ def __getitem__(self, name):
F = _RecordQueryProxy()


class Record(SourceObject):
class Record(DBSourceObject):
source_classification = "record"
supports_pagination = False

def __init__(self, pad, data, page_num=None):
SourceObject.__init__(self, pad)
super().__init__(pad)
self._data = data
self._bound_data = {}
if page_num is not None and not self.supports_pagination:
Expand All @@ -333,7 +333,7 @@ def datamodel(self):
@property
def alt(self):
"""Returns the alt of this source object."""
return self["_alt"]
return self._data["_alt"]

@property
def is_hidden(self):
Expand Down Expand Up @@ -438,7 +438,7 @@ def _get_url_path(self, alt=None):

@property
def path(self):
return self["_path"]
return self._data["_path"]

def get_sort_key(self, fields):
"""Returns a sort key for the given field specifications specific
Expand Down Expand Up @@ -472,21 +472,11 @@ def __getitem__(self, name):
self._bound_data[name] = rv
return rv

def __eq__(self, other):
if self is other:
return True
if self.__class__ != other.__class__:
return False
return self["_path"] == other["_path"]

def __hash__(self):
return hash(self.path)

def __repr__(self):
return "<%s model=%r path=%r%s%s>" % (
self.__class__.__name__,
self["_model"],
self["_path"],
self._data["_model"],
self._data["_path"],
self.alt != PRIMARY_ALT and " alt=%r" % self.alt or "",
self.page_num is not None and " page_num=%r" % self.page_num or "",
)
Expand Down Expand Up @@ -545,7 +535,7 @@ class Page(Record):

@cached_property
def path(self):
rv = self["_path"]
rv = self._data["_path"]
if self.page_num is not None:
rv = "%s@%s" % (rv, self.page_num)
return rv
Expand All @@ -555,21 +545,16 @@ def record(self):
if self.page_num is None:
return self
return self.pad.get(
self["_path"], persist=self.pad.cache.is_persistent(self), alt=self.alt
self._data["_path"],
persist=self.pad.cache.is_persistent(self),
alt=self.alt,
)

@property
def source_filename(self):
if self.alt != PRIMARY_ALT:
return os.path.join(
self.pad.db.to_fs_path(self["_path"]), "contents+%s.lr" % self.alt
)
return os.path.join(self.pad.db.to_fs_path(self["_path"]), "contents.lr")

def iter_source_filenames(self):
yield self.source_filename
fs_path = self.pad.db.to_fs_path(self._data["_path"])
if self.alt != PRIMARY_ALT:
yield os.path.join(self.pad.db.to_fs_path(self["_path"]), "contents.lr")
yield os.path.join(fs_path, f"contents+{self.alt}.lr")
yield os.path.join(fs_path, "contents.lr")

@property
def url_path(self):
Expand Down Expand Up @@ -665,12 +650,12 @@ def children(self):
repl_query = self.datamodel.get_child_replacements(self)
if repl_query is not None:
return repl_query.include_undiscoverable(False)
return Query(path=self["_path"], pad=self.pad, alt=self.alt)
return Query(path=self._data["_path"], pad=self.pad, alt=self.alt)

@property
def attachments(self):
"""Returns a query for the attachments of this record."""
return AttachmentsQuery(path=self["_path"], pad=self.pad, alt=self.alt)
return AttachmentsQuery(path=self._data["_path"], pad=self.pad, alt=self.alt)

def has_prev(self):
return self.get_siblings().prev_page is not None
Expand Down Expand Up @@ -723,14 +708,6 @@ class Attachment(Record):

is_attachment = True

@property
def source_filename(self):
if self.alt != PRIMARY_ALT:
suffix = "+%s.lr" % self.alt
else:
suffix = ".lr"
return self.pad.db.to_fs_path(self["_path"]) + suffix

def _is_considered_hidden(self):
# Attachments are only considered hidden if they have been
# configured as such. This means that even if a record itself is
Expand All @@ -746,7 +723,7 @@ def record(self):

@property
def attachment_filename(self):
return self.pad.db.to_fs_path(self["_path"])
return self.pad.db.to_fs_path(self._data["_path"])

@property
def parent(self):
Expand All @@ -764,10 +741,11 @@ def get_fallback_record_label(self, lang):
return self["_id"]

def iter_source_filenames(self):
yield self.source_filename
attachment_filename = self.attachment_filename
if self.alt != PRIMARY_ALT:
yield self.pad.db.to_fs_path(self["_path"]) + ".lr"
yield self.attachment_filename
yield f"{attachment_filename}+{self.alt}.lr"
yield f"{attachment_filename}.lr"
yield attachment_filename

@property
def url_path(self):
Expand Down Expand Up @@ -1736,7 +1714,7 @@ def get_virtual(self, record, virtual_path):
if pieces[0].isdigit():
if len(pieces) == 1:
return self.get(
record["_path"], alt=record.alt, page_num=int(pieces[0])
record._data["_path"], alt=record.alt, page_num=int(pieces[0])
)
return None

Expand Down
62 changes: 52 additions & 10 deletions lektor/sourceobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ def alt(self):

@property
def source_filename(self):
"""The primary source filename of this source object."""
"""The primary source filename of this source object.
In general, subclasses should implement/override ``iter_source_filenames``
rather than this property.
"""
source_filenames = self.iter_source_filenames()
return next(iter(source_filenames), None)

is_hidden = False
is_discoverable = True
Expand All @@ -43,13 +49,16 @@ def is_undiscoverable(self):
return not self.is_discoverable

def iter_source_filenames(self):
fn = self.source_filename
if fn is not None:
yield self.source_filename
"""An iterable of the source filenames for this source object.
The first returned filename should be the "primary" one.
"""
# pylint: disable=no-self-use
return ()

def iter_virtual_sources(self):
# pylint: disable=no-self-use
return []
return ()

@property
def url_path(self):
Expand Down Expand Up @@ -205,13 +214,43 @@ def _resolve_url(
return resolved.geturl()


class VirtualSourceObject(SourceObject):
class DBSourceObject(SourceObject):
"""This is the base class for objects that live in the lektor db.
I.e. this is the type of object returned by pad.get().
"""

@property
def path(self):
"""Return the full database path to the source object.
All DBSourceObjects must have paths.
"""
raise NotImplementedError()

# XXX: move SourceObject.url_to here?

def __eq__(self, other):
if other is self:
return True # optimization
if other.__class__ is not self.__class__:
return False # optimization
return (
other.alt == self.alt and other.path == self.path and other.pad == self.pad
)

def __hash__(self):
return hash((self.path, self.alt))


class VirtualSourceObject(DBSourceObject):
"""Virtual source objects live below a parent record but do not
originate from the source tree with a separate file.
"""

def __init__(self, record):
SourceObject.__init__(self, record.pad)
super().__init__(record.pad)
self.record = record

@property
Expand All @@ -234,9 +273,12 @@ def parent(self):
def alt(self):
return self.record.alt

@property
def source_filename(self):
return self.record.source_filename
def iter_source_filenames(self):
# This is a default. However, if artifacts produced from a
# particular virtual source type do not explicitly vary with
# the parent record, it may make sense to override this to
# return an empty (or some other) list of file names.
return self.record.iter_source_filenames()

def iter_virtual_sources(self):
yield self
2 changes: 1 addition & 1 deletion tests/test_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_asset_children(asset, child_names):

@pytest.mark.parametrize("asset_path", ["/static"])
def test_asset_children_no_children_if_dir_unreadable(asset):
asset.source_filename += "-missing"
asset._source_filename += "-missing"
assert len(set(asset.children)) == 0


Expand Down
38 changes: 38 additions & 0 deletions tests/test_sourceobj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest


@pytest.mark.parametrize(
"key1, key2, is_eq",
[
({"path": "/"}, {"path": "/"}, True),
({"path": "/"}, {"path": "/", "alt": "en"}, True),
({"path": "/"}, {"path": "/blog"}, False),
({"path": "/blog/post1"}, {"path": "/blog/post1@siblings"}, False),
({"path": "/"}, {"path": "/", "page_num": 1}, False),
({"path": "/"}, {"path": "/", "alt": "de"}, False),
({"path": "/", "alt": "en"}, {"path": "/", "alt": "de"}, False),
],
)
def test_records_eq(pad, key1, key2, is_eq):
r1 = pad.get(**key1)
pad.cache.flush()
r2 = pad.get(**key2)
if is_eq:
assert r1 == r2
assert hash(r1) == hash(r2)
else:
assert r1 != r2


def test_records_from_different_pads_ne(env):
pad1 = env.new_pad()
pad2 = env.new_pad()
assert pad1.get("/") == pad1.get("/")
assert pad1.get("/") != pad2.get("/")


def test_asset_ne_record(pad):
record = pad.get("/")
asset = pad.get_asset("/")
assert record != asset
assert asset != record

0 comments on commit 63ec21d

Please sign in to comment.