From 9eb4d5bfc6d0f0a1135ececa477810983bea0f6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Sat, 3 May 2014 12:04:46 +0200 Subject: [PATCH] Support deep attribute lookup --- CHANGES.rst | 1 + madseq.py | 46 +++++++++++++++++++++++++++++++++------------- test.py | 16 ++++++++++++---- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 01f833a..ca5a28e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -5,6 +5,7 @@ Changelog ===== - use `semantic versioning ` +- fix deep attribute lookup for elements 0.3 diff --git a/madseq.py b/madseq.py index 0473aa6..6e2f4f8 100755 --- a/madseq.py +++ b/madseq.py @@ -268,7 +268,7 @@ class Element(object): __slots__ = ['name', 'type', 'args'] - def __init__(self, name, type, args): + def __init__(self, name, type, args, base=None): """ Initialize an Element object. @@ -279,6 +279,7 @@ def __init__(self, name, type, args): self.name = stri(name) self.type = stri(type) self.args = args + self._base = base @classmethod def parse(cls, text): @@ -290,7 +291,7 @@ def copy(self): return self.__class__(self.name, self.type, self.args.copy()) def __getattr__(self, key): - return self.args[key] + return self[key] def __setattr__(self, key, val): if key in self.__slots__: @@ -304,11 +305,30 @@ def __contains__(self, key): def __delattr__(self, key): del self.args[key] + def __getitem__(self, key): + try: + return self.args[key] + except KeyError: + if self._base: + return self._base[key] + raise + def get(self, key, default=None): - return self.args.get(key, default) + try: + return self[key] + except KeyError: + return default def pop(self, key, *default): - return self.args.pop(key, *default) + try: + return self.args.pop(key) + except KeyError: + try: + return self._base[key] + except (KeyError, TypeError): + if default: + return default[0] + raise def __str__(self): """Output element in MAD-X format.""" @@ -373,9 +393,9 @@ def __init__(self, slicing): self.transforms = [ElementTransform(s) for s in slicing] + [] self.transforms.append(ElementTransform({})) - def __call__(self, node, document): + def __call__(self, node, defs): if isinstance(node, (Element, Sequence)): - document._defs[node.name] = node + defs[node.name] = node if not isinstance(node, Sequence): return node @@ -386,9 +406,11 @@ def __call__(self, node, document): refer = self.offsets[str(first.get('refer', 'centre'))] def transform(elem, offset): + if elem.type: + elem._base = defs.get(elem.type) for t in self.transforms: if t.match(elem): - return t.replace(elem, offset, refer, document._defs.get(elem.type)) + return t.replace(elem, offset, refer) templates = [] # predefined element templates elements = [] # actual elements to put in sequence @@ -463,10 +485,8 @@ def make_optic(elem, elem_len, slice_num): else: raise ValueError("Unknown slicing style: {!r}".format(style)) - def replace(self, elem, offset, refer, parent): - elem_len = elem.get('L') - if elem_len is None: - elem_len = parent.get('L', 0) if parent else 0 + def replace(self, elem, offset, refer): + elem_len = elem.get('L', 0) slice_num = self._get_slice_num(elem_len) or 1 optic = self._makeoptic(elem, slice_num) elem = self._stripelem(elem) @@ -628,11 +648,11 @@ class Document(list): def __init__(self, nodes): self._nodes = list(nodes) - self._defs = dicti() # TODO: lookup table for template elements def transform(self, node_transform): - return Document(node_transform(node, self) for node in self._nodes) + return Document(node_transform(node, self, dicti()) + for node in self._nodes) @classmethod def parse(cls, lines): diff --git a/test.py b/test.py index 03f20f5..27b2d08 100644 --- a/test.py +++ b/test.py @@ -150,15 +150,23 @@ def test_parse_format_identity(self): self.assertEqual(el.c, 99) self.assertEqual(el.E, 101) + def test_deep_lookup(self): + el0 = madseq.Element(None, None, dicti(a='a0', b='b0', c='c0')) + el1 = madseq.Element(None, None, dicti(a='a1', b='b1', d='d1'), el0) + el2 = madseq.Element(None, None, dicti(a='a2'), el1) + self.assertEqual(el2.a, 'a2') + self.assertEqual(el2.b, 'b1') + self.assertEqual(el2.c, 'c0') + self.assertEqual(el2.d, 'd1') + class TestElementTransform(unittest.TestCase): def test_replace_with_parent(self): - base = madseq.Element('BASE', 'DRIFT', dicti(l=1.5)) - elem = madseq.Element(None, 'BASE', dicti()) + base = madseq.Element('BASE', 'DRIFT', dicti(l=1.5, k=2)) + elem = madseq.Element(None, 'BASE', dicti(), base) transformer = madseq.ElementTransform({}) - - tpl, el, l = transformer.replace(elem, 0, 0, base) + tpl, el, l = transformer.replace(elem, 0, 0) self.assertEqual(l, 1.5)