Skip to content
This repository has been archived by the owner on Aug 29, 2020. It is now read-only.

Commit

Permalink
Support deep attribute lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
coldfix committed May 3, 2014
1 parent 1911fa1 commit 9eb4d5b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Changelog
=====

- use `semantic versioning <http://semver.org/>`
- fix deep attribute lookup for elements


0.3
Expand Down
46 changes: 33 additions & 13 deletions madseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class Element(object):

__slots__ = ['name', 'type', 'args']

def __init__(self, name, type, args):
def __init__(self, name, type, args, base=None):
"""
Initialize an Element object.
Expand All @@ -279,6 +279,7 @@ def __init__(self, name, type, args):
self.name = stri(name)
self.type = stri(type)
self.args = args
self._base = base

@classmethod
def parse(cls, text):
Expand All @@ -290,7 +291,7 @@ def copy(self):
return self.__class__(self.name, self.type, self.args.copy())

def __getattr__(self, key):
return self.args[key]
return self[key]

def __setattr__(self, key, val):
if key in self.__slots__:
Expand All @@ -304,11 +305,30 @@ def __contains__(self, key):
def __delattr__(self, key):
del self.args[key]

def __getitem__(self, key):
try:
return self.args[key]
except KeyError:
if self._base:
return self._base[key]
raise

def get(self, key, default=None):
return self.args.get(key, default)
try:
return self[key]
except KeyError:
return default

def pop(self, key, *default):
return self.args.pop(key, *default)
try:
return self.args.pop(key)
except KeyError:
try:
return self._base[key]
except (KeyError, TypeError):
if default:
return default[0]
raise

def __str__(self):
"""Output element in MAD-X format."""
Expand Down Expand Up @@ -373,9 +393,9 @@ def __init__(self, slicing):
self.transforms = [ElementTransform(s) for s in slicing] + []
self.transforms.append(ElementTransform({}))

def __call__(self, node, document):
def __call__(self, node, defs):
if isinstance(node, (Element, Sequence)):
document._defs[node.name] = node
defs[node.name] = node

if not isinstance(node, Sequence):
return node
Expand All @@ -386,9 +406,11 @@ def __call__(self, node, document):
refer = self.offsets[str(first.get('refer', 'centre'))]

def transform(elem, offset):
if elem.type:
elem._base = defs.get(elem.type)
for t in self.transforms:
if t.match(elem):
return t.replace(elem, offset, refer, document._defs.get(elem.type))
return t.replace(elem, offset, refer)

templates = [] # predefined element templates
elements = [] # actual elements to put in sequence
Expand Down Expand Up @@ -463,10 +485,8 @@ def make_optic(elem, elem_len, slice_num):
else:
raise ValueError("Unknown slicing style: {!r}".format(style))

def replace(self, elem, offset, refer, parent):
elem_len = elem.get('L')
if elem_len is None:
elem_len = parent.get('L', 0) if parent else 0
def replace(self, elem, offset, refer):
elem_len = elem.get('L', 0)
slice_num = self._get_slice_num(elem_len) or 1
optic = self._makeoptic(elem, slice_num)
elem = self._stripelem(elem)
Expand Down Expand Up @@ -628,11 +648,11 @@ class Document(list):

def __init__(self, nodes):
self._nodes = list(nodes)
self._defs = dicti()
# TODO: lookup table for template elements

def transform(self, node_transform):
return Document(node_transform(node, self) for node in self._nodes)
return Document(node_transform(node, self, dicti())
for node in self._nodes)

@classmethod
def parse(cls, lines):
Expand Down
16 changes: 12 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,23 @@ def test_parse_format_identity(self):
self.assertEqual(el.c, 99)
self.assertEqual(el.E, 101)

def test_deep_lookup(self):
el0 = madseq.Element(None, None, dicti(a='a0', b='b0', c='c0'))
el1 = madseq.Element(None, None, dicti(a='a1', b='b1', d='d1'), el0)
el2 = madseq.Element(None, None, dicti(a='a2'), el1)
self.assertEqual(el2.a, 'a2')
self.assertEqual(el2.b, 'b1')
self.assertEqual(el2.c, 'c0')
self.assertEqual(el2.d, 'd1')


class TestElementTransform(unittest.TestCase):

def test_replace_with_parent(self):
base = madseq.Element('BASE', 'DRIFT', dicti(l=1.5))
elem = madseq.Element(None, 'BASE', dicti())
base = madseq.Element('BASE', 'DRIFT', dicti(l=1.5, k=2))
elem = madseq.Element(None, 'BASE', dicti(), base)
transformer = madseq.ElementTransform({})

tpl, el, l = transformer.replace(elem, 0, 0, base)
tpl, el, l = transformer.replace(elem, 0, 0)
self.assertEqual(l, 1.5)


Expand Down

0 comments on commit 9eb4d5b

Please sign in to comment.