From 3e87e43373de6cf4e8bf1de03f5289a747d65b83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Thu, 1 May 2014 14:54:38 +0200 Subject: [PATCH 1/8] Add Yaml.load --- madseq.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/madseq.py b/madseq.py index f1d1fee..88753d9 100755 --- a/madseq.py +++ b/madseq.py @@ -600,6 +600,7 @@ def json_adjust_element(elem): #---------------------------------------- class Yaml(object): + def __init__(self): import yaml import pydicti @@ -621,17 +622,23 @@ def _Value_representer(dumper, data): def _Decimal_representer(dumper, data): return dumper.represent_scalar(u'tag:yaml.org,2002:float', str(data).lower()) - Dumper.add_representer(self.dict, _dict_representer) Dumper.add_representer(stri.cls, _stri_representer) Dumper.add_representer(Symbolic, _Value_representer) Dumper.add_representer(Identifier, _Value_representer) Dumper.add_representer(Composed, _Value_representer) Dumper.add_representer(decimal.Decimal, _Decimal_representer) - return yaml.dump(data, stream, Dumper, default_flow_style=False, **kwds) + def load(self, stream, **kwds): + class OrderedLoader(Loader): + pass + OrderedLoader.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, + lambda loader, node: self.dict(loader.construct_pairs(node))) + return yaml.load(stream, OrderedLoader) + #---------------------------------------- # main #---------------------------------------- From 7484e7977c701618611b09ade34a4ca3c76943bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Thu, 1 May 2014 21:56:38 +0200 Subject: [PATCH 2/8] Make -json, -yaml pure format flags --json and -yaml are now used to set the output format and not to add additional streams as before. Furthermore, the optics (template) elements file can not be set via the SEQUENCE command anymore. --- madseq.py | 99 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 52 insertions(+), 47 deletions(-) diff --git a/madseq.py b/madseq.py index 88753d9..c91195e 100755 --- a/madseq.py +++ b/madseq.py @@ -3,13 +3,13 @@ madseq - MAD-X sequence parser/transformer. Usage: - madseq.py [-j ] [-y ] [-s ] [-m] [-o ] [] + madseq.py [-j|-y] [-s ] [-m] [-o ] [] madseq.py (--help | --version) Options: -o , --output= Set output file - -j , --json= Set JSON output file - -y , --yaml= Set YAML output file + -j, --json Use JSON as output format + -y, --yaml Use YAML as output format -s , --slice= Select slicing -m, --makethin Apply a MAKETHIN like transformation -h, --help Show this help @@ -386,12 +386,14 @@ class Text(str): class Sequence(object): """MadX sequence.""" - def __init__(self, seq): + def __init__(self, seq, opt=None, name=None): + self.name = name + self.opt = opt self.seq = seq def __str__(self, code): """Format sequence to MadX format.""" - return '%s\n'.join(self.seq) + return '\n'.join(map(str, (self.opt or []) + self.seq)) @classmethod def detect(cls, elements): @@ -643,7 +645,7 @@ class OrderedLoader(Loader): # main #---------------------------------------- -def transform(elem, json_file=None, yaml_file=None, +def transform(elem, typecast=Typecast.preserve, slicing=None): """Transform sequence.""" @@ -658,10 +660,6 @@ def transform(elem, json_file=None, yaml_file=None, # TODO: when to slice: explicit/always/never/{select classes} - # select optics routine - optics_file = first.pop('optics', 'inline') - if optics_file == 'inline': - optics_file = None # output method slice_method = getattr(Slice, first.pop('method', 'simple').lower()) @@ -690,30 +688,7 @@ def transform(elem, json_file=None, yaml_file=None, optics.insert(0, Text('! Optics definition for %s:' % first.get('name'))) optics.append(Text()) - if optics_file: - open(optics_file, 'wt').write( - '\n'.join(map(format_element, optics))) - optics = [] - - if json_file: - json.dump( - list(chain.from_iterable(map(json_adjust_element, elems))), - open(json_file, 'wt'), - indent=3, - separators=(',', ' : '), - cls=ValueEncoder) - - if yaml_file: - yaml = Yaml() - yaml.dump( - list(chain.from_iterable(map(json_adjust_element, elems))), - open(yaml_file, 'wt')) - - return (optics + - [Text('! Sequence definition for %s:' % first.name), - first] + - elems + - [last]) + return Sequence([first] + elems + [last], optics, first.get('name')) class File(list): @@ -768,30 +743,60 @@ def main(argv=None): # prepare filters typecast = Typecast.multipole if args['--makethin'] else Typecast.preserve transformation = partial(transform, - json_file=args['--json'], - yaml_file=args['--yaml'], typecast=typecast, slicing=args['--slice']) # perform input - if args['']: + if args[''] and args[''] != '-': with open(args[''], 'rt') as f: input_file = list(f) else: from sys import stdin as input_file - # parse data and apply transformations - original = Sequence.detect(File.parse(input_file)) - processed = chain.from_iterable(map(transformation, original)) - text = "\n".join(map(str, processed)) + # open output stream + if args['--output'] and args['--output'] != '-': + output_file = open(args['--output'], 'wt') + else: + from sys import stdout as output_file + + def load(file): + return Sequence.detect(File.parse(file)) + + def transform(input_data): + return (transformation(el) + for el in input_data) + + def serialize(output_data): + return odicti( + (seq.name, + list(chain.from_iterable(map(json_adjust_element, seq.seq)))) + for seq in output_data + if isinstance(seq, Sequence)) + + if args['--json']: + import json + def dump(output_data): + json.dump( + serialize(output_data), + output_file, + indent=3, + separators=(',', ' : '), + cls=ValueEncoder) + + elif args['--yaml']: + yaml = Yaml() + def dump(output_data): + yaml.dump( + serialize(output_data), + output_file) - # perform output - if args['--output']: - with open(args['--output'], 'wt') as f: - f.write(text) else: - from sys import stdout - stdout.write(text) + def dump(output_data): + output_file.write("\n".join(map(str, output_data))) + + # one line to do it all: + dump(transform(load(input_file))) + main.__doc__ = __doc__ if __name__ == '__main__': From e144a2a459b8124cc06de6a20d8f082c19af41e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Thu, 1 May 2014 22:01:51 +0200 Subject: [PATCH 3/8] Make file parameter positional --- madseq.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/madseq.py b/madseq.py index c91195e..b4631a4 100755 --- a/madseq.py +++ b/madseq.py @@ -3,11 +3,10 @@ madseq - MAD-X sequence parser/transformer. Usage: - madseq.py [-j|-y] [-s ] [-m] [-o ] [] + madseq.py [-j|-y] [-s ] [-m] [] [] madseq.py (--help | --version) Options: - -o , --output= Set output file -j, --json Use JSON as output format -y, --yaml Use YAML as output format -s , --slice= Select slicing @@ -754,8 +753,8 @@ def main(argv=None): from sys import stdin as input_file # open output stream - if args['--output'] and args['--output'] != '-': - output_file = open(args['--output'], 'wt') + if args[''] and args[''] != '-': + output_file = open(args[''], 'wt') else: from sys import stdout as output_file From 52c08216e4e0236e2e6db0f990baeb0cfd5ecd5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Thu, 1 May 2014 22:04:50 +0200 Subject: [PATCH 4/8] Remove variables from output --- madseq.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/madseq.py b/madseq.py index b4631a4..3ef466b 100755 --- a/madseq.py +++ b/madseq.py @@ -567,32 +567,11 @@ def loops(offset, refer, elem): # json output #---------------------------------------- -def coeffs_quadrupole(elem): - return (getattr(elem, coeff) - for coeff in ('K1', 'K1S') - if isinstance(elem.get(coeff), Identifier)) - -def get_variables(elem): - coeffs = dicti(quadrupole=coeffs_quadrupole) - - if not (elem.type and elem.at and - elem.type in coeffs): - return [] - - return [('vary', [ - odicti([ - ('name', str(coeff)), - ('step', 1e-6)]) - for coeff in coeffs.get(elem.type)(elem) ] - )] - - def json_adjust_element(elem): if not elem.type: return () return odicti([('name', elem.name), ('type', elem.type)] + - get_variables(elem) + [(k,v) for k,v in elem.args.items() if v is not None]), From 2a57502c3b60758549f4fcc5efd2d8a005272f1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Thu, 1 May 2014 22:05:44 +0200 Subject: [PATCH 5/8] Add madseq.__version__ information --- madseq.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/madseq.py b/madseq.py index 3ef466b..d6c2f36 100755 --- a/madseq.py +++ b/madseq.py @@ -17,6 +17,8 @@ """ from __future__ import division +__version__ = 'madseq 0.2' + __all__ = [ 'Element', 'Sequence', 'File', 'main' @@ -716,7 +718,7 @@ def __str__(self, code): def main(argv=None): # parse command line options from docopt import docopt - args = docopt(__doc__, argv, version='madseq.py 0.1') + args = docopt(__doc__, argv, version=__version__) # prepare filters typecast = Typecast.multipole if args['--makethin'] else Typecast.preserve From 2ebe9e06bc9dbb557a1964519fbc7bc1d5a11ef3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Thu, 1 May 2014 22:55:35 +0200 Subject: [PATCH 6/8] Specify slicing options via file The old form to define slicing options inline is now obsolete Also add a preliminary unit test file --- madseq.py | 669 +++++++++++++++++++++++------------------------------- setup.py | 3 +- test.py | 197 ++++++++++++++++ 3 files changed, 486 insertions(+), 383 deletions(-) create mode 100644 test.py diff --git a/madseq.py b/madseq.py index d6c2f36..9c058b4 100755 --- a/madseq.py +++ b/madseq.py @@ -3,110 +3,90 @@ madseq - MAD-X sequence parser/transformer. Usage: - madseq.py [-j|-y] [-s ] [-m] [] [] + madseq.py [-j|-y] [-s ] [] [] madseq.py (--help | --version) Options: -j, --json Use JSON as output format -y, --yaml Use YAML as output format - -s , --slice= Select slicing - -m, --makethin Apply a MAKETHIN like transformation + -s , --slice= Set slicing definition file -h, --help Show this help -v, --version Show version information +Madseq is a MAD-X sequence parser and transformation utility. If called with +only a MAD-X input file, it will look for SEQUENCE..ENDSEQUENCE sections in +the file and update the AT=.. values of all elements. """ -from __future__ import division -__version__ = 'madseq 0.2' -__all__ = [ - 'Element', 'Sequence', 'File', - 'main' -] +from __future__ import division -from pydicti import odicti, dicti +# standard library from itertools import chain from functools import partial -import json import re from math import ceil -from copy import copy -import decimal +from decimal import Decimal, InvalidOperation + +# 3rd-party +from pydicti import odicti, dicti #---------------------------------------- -# utilities +# meta data #---------------------------------------- -def disp(v): - print("%s(%s)" % (v.__class__.__name__, v)) -def cast(type): - """ - Create a simple non-checked constructor. +__version__ = 'madseq 0.2' + +__all__ = [ + 'Element', 'Sequence', 'Document', + 'main' +] - >>> tostr = cast(str) - >>> tostr(None) is None - True - >>> tostr(2) == '2' - True - >>> isinstance(tostr(2), str) - True - """ +#---------------------------------------- +# Utilities +#---------------------------------------- + +def none_checked(type): + """Create a simple ``None``-checked constructor.""" def constructor(value): return None if value is None else type(value) constructor.cls = type return constructor -@cast -class stri(str): - """ - String with case insensitive comparison. - - >>> stri("HeLLo") == "helLO" and "HeLLo" == stri("helLO") - True - >>> stri("wOrLd") != "WOrld" or "wOrLd" != stri("WOrld") - False - >>> stri("HeLLo") == "WOrld" or "HeLLo" == stri("WOrld") - False - >>> stri("wOrLd") != "helLO" and "wOrLd" != stri("helLO") - True - >>> stri("HEllO wORld") - HEllO wORld - """ +@none_checked +class stri(str): + """Case insensitive string.""" def __eq__(self, other): return self.lower() == str(other).lower() def __ne__(self, other): return not (self == other) + class Re(object): - """ - Precompiled regular expressions. - >>> r1 = Re('hello') - >>> r2 = Re(r1, 'world') - >>> assert(r1.search(' helloworld ')) - >>> assert(not r1.search('yelloworld')) - >>> assert(r2.match('helloworld anything')) - >>> assert(not r2.match(' helloworld anything')) + """Precompiled regular expressions.""" - """ def __init__(self, *args): """Concat the arguments.""" self.s = ''.join(map(str, args)) self.r = re.compile(self.s) def __str__(self): - """Display as the expression that was used to create the regex.""" + """Return the expression that was used to create the regex.""" return self.s def __getattr__(self, key): """Delegate attribute access to the precompiled regex.""" return getattr(self.r, key) + class regex(object): + """List of regular expressions used in this script.""" + integer = Re(r'(?:\d+)') number = Re(r'(?:[+\-]?(?:\d+(?:\.\d*)?|\d*\.\d+)(?:[eE][+\-]?\d+)?)') thingy = Re(r'(?:[^\s,;!]+)') @@ -119,50 +99,39 @@ class regex(object): comment_split = Re(r'^([^!]*)(!.*)?$') - slice_per_m = Re(r'^(',number,r')\/m$') - is_string = Re(r'^\s*(?:"([^"]*)")\s*$') is_identifier = Re(r'^\s*(',identifier,')\s*$') + #---------------------------------------- # Line model + parsing + formatting #---------------------------------------- -class fakefloat(float): - """Used to serialize decimal.Decimal. - See: http://stackoverflow.com/a/8274307/650222""" - def __init__(self, value): - self._value = value - def __repr__(self): - return str(self._value) -class ValueEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, Value): - return obj.value - if isinstance(obj, decimal.Decimal): - return fakefloat(obj.normalize()) - # Let the base class default method raise the TypeError - return json.JSONEncoder.default(self, obj) def fmtArg(value): - if isinstance(value, decimal.Decimal): + if isinstance(value, Decimal): return '=' + str(value.normalize()) elif isinstance(value, str): return '="%s"' % value elif isinstance(value, (float, int)): return '=' + str(value) + elif isinstance(value, (tuple, list)): + return '={%s}' % ','.join(map(fmtInner, value)) else: return value.fmtArg() + def fmtInner(value): - if isinstance(value, decimal.Decimal): + if isinstance(value, Decimal): return str(value.normalize()) try: return value.fmtInner() except AttributeError: return str(value) + class Value(object): + def __init__(self, value, assign='='): self.value = value self.assign = assign @@ -176,68 +145,38 @@ def fmtArg(self): @classmethod def parse(cls, text, assign='='): try: - return Number.parse(text) + return parse_number(text) except ValueError: try: - return String.parse(text) + return parse_string(text) except ValueError: return Symbolic.parse(text, assign) -class Number(object): - """ - Used to parse numberic values. - """ - @classmethod - def parse(cls, text): - """ - Parse numeric value. - - >>> disp(Number.parse('-13')) - int(-13) - >>> disp(Number.parse('12.')) - float(12.0) - >>> disp(Number.parse('1.2e1')) - float(12.0) - """ +def parse_number(text): + """Parse numeric value as :class:`int` or :class:`Decimal`.""" + try: + return int(text) + except ValueError: try: - return int(text) - except ValueError: - try: - return decimal.Decimal(text) - except decimal.InvalidOperation: - raise ValueError("Not a floating point: %s" % text) + return Decimal(text) + except InvalidOperation: + raise ValueError("Not a floating point: %s" % text) + -class String(object): +@none_checked +def parse_string(text): """Used to parse string values.""" - @classmethod - def parse(cls, text): - if text is None: - return None - try: - return regex.is_string.match(str(text)).groups()[0] - except AttributeError: - raise ValueError("Invalid string: %s" % (text,)) + try: + return regex.is_string.match(str(text)).groups()[0] + except AttributeError: + raise ValueError("Invalid string: %s" % (text,)) class Symbolic(Value): - """ - Symbolic value. - - >>> i = Number.parse('-13') - >>> f = Number.parse('12.') - >>> s = Symbolic.parse('pi') - >>> disp(f + s) - Composed(12.0 + pi) - >>> disp(s - f) - Composed(pi - 12.0) - >>> disp(i + f * s) - Composed(-13 + (12.0 * pi)) - >>> disp(s / s) - Composed(pi / pi) + """Base class for identifiers and composed arithmetic expressions.""" - """ @classmethod def parse(cls, text, assign=False): try: @@ -273,8 +212,11 @@ def parse(cls, text, assign='='): except AttributeError: raise ValueError("Invalid identifier: %s" % (text,)) + class Composed(Symbolic): + """Composed value.""" + @classmethod def parse(cls, text, assign='='): return cls(text, assign) @@ -294,12 +236,14 @@ def parse_args(text): return odicti((key, Value.parse(val, assign)) for key,assign,val in regex.arg.findall(text or '')) + class Element(object): + """ Single MAD-X element. """ - # TODO: json - __slots__ = ['name', 'type', 'args', 'slice', 'slice_len', 'slice_num'] + + __slots__ = ['name', 'type', 'args'] def __init__(self, name, type, args): """ @@ -308,7 +252,6 @@ def __init__(self, name, type, args): :param str name: name of the element (colon prefix) :param str type: command name or element type :param dict args: command arguments - """ self.name = stri(name) self.type = stri(type) @@ -316,20 +259,13 @@ def __init__(self, name, type, args): @classmethod def parse(cls, text): - """ - Parse element from string. - - >>> mad = "name: type, a=97, b=98, c=99, d=100, e=101;" - >>> el = Element.parse(mad) - >>> str(mad) == mad - True - >>> el.c, el.E - (99, 101) - - """ + """Parse element from MAD-X string.""" name, type, args = regex.cmd.match(text).groups() return Element(name, type, parse_args(args)) + def copy(self): + return self.__class__(self.name, self.type, self.args.copy()) + def __getattr__(self, key): return self.args[key] @@ -352,13 +288,7 @@ def pop(self, key, *default): return self.args.pop(key, *default) def __str__(self): - """ - Output element in MAD-X format. - - >>> str(Element('name', 'type', odicti(zip("abcde", range(5))))) - 'name: type, a=0, b=1, c=2, d=3, e=4;' - - """ + """Output element in MAD-X format.""" def _fmt_arg(k, v): return ', %s' % k if v is None else ', %s%s' % (k,fmtArg(v)) return '%s%s%s;' % ('%s: ' % self.name if self.name else '', @@ -366,34 +296,27 @@ def _fmt_arg(k, v): ''.join(_fmt_arg(k, v) for k,v in self.args.items())) - def __repr__(self): - """ - Representation, mainly used to write tests. - - >>> repr(Element('name', 'type', odicti(zip("abcde", range(5))))) - "Element('name', 'type', a=0, b=1, c=2, d=3, e=4)" + def __eq__(self, other): + return (self.name == other.name and + self.type == other.type and + self.args == other.args) - """ - def _fmt_arg(k, v): - return ', %s' % k if v is None else ', %s%s' % (k,fmtArg(v)) - return '%s(%r, %r%s)' % ( - self.__class__.__name__, - self.name, - self.type, - ''.join(_fmt_arg(k, v) for k,v in self.args.items())) class Text(str): type = None + class Sequence(object): - """MadX sequence.""" + + """MAD-X sequence.""" + def __init__(self, seq, opt=None, name=None): self.name = name self.opt = opt self.seq = seq - def __str__(self, code): - """Format sequence to MadX format.""" + def __str__(self): + """Format sequence to MAD-X format.""" return '\n'.join(map(str, (self.opt or []) + self.seq)) @classmethod @@ -415,161 +338,133 @@ def detect(cls, elements): # Transformations #---------------------------------------- -def filter_default(offset, refer, elem): - elem.at = offset + refer*elem.get('L', 0) - return [elem] - -def detect_slicing(elem, default_slicing): - # fall through for elements without explicit slice attribute - slicing = elem.pop('slice', default_slicing) - if not slicing: - return None - elem_len = elem.get('L', 0) - if elem_len == 0: - return None - - # determine slice number, length - if isinstance(slicing, int): - elem.slice_num = slicing - elem.slice_len = elem_len / slicing - else: - m = regex.slice_per_m.match(slicing) - if not m: - raise ValueError("Invalid slicing: %s" % slicing) - slice_per_m = decimal.Decimal(m.groups()[0]) - elem.slice_num = int(ceil(abs(elem_len * slice_per_m))) - elem.slice_len = elem_len / elem.slice_num - - # replace L property - elem.L = elem.slice_len - return elem +def exclusive(mapping, *keys): + return sum(key in mapping for key in keys) <= 1 -class Typecast(object): - """ - Namespace for MAKETHIN-like transformations. - """ - @staticmethod - def preserve(elem): - """ - Leave the elements 'as is'. - Leave their type unchanged for later transformation in - MADX.MAKETHIN. +class Transform(object): - >>> el = Element(None, 'SBEND', {'angle': 3.14, 'slice_num': 2}) - >>> Typecast.preserve(el) - >>> el.angle - 1.57 - >>> el.type - 'SBEND' + def __init__(self, selector): - """ - if elem.type == 'sbend' and elem.slice_num > 1: - elem.angle = elem.angle / elem.slice_num + # matching criterium + exclusive(selector, 'name', 'type') + if 'name' in selector: + name = selector['name'] + self.match = lambda elem: elem.name == name + elif 'type' in selector: + type = selector['type'] + self.match = lambda elem: elem.type == type + else: + self.match = lambda elem: True - @staticmethod - def multipole(elem): - """ - Transform the elements to MULTIPOLE elements. - - NOTE: Typecast.multipole is currently not recommended! If you use - it, you have to make sure, your slice length will be sufficiently - small! You should use Mad-X' MAKETHIN or not use it at all! - - >>> sbend = Element(None, 'SBEND', dicti(angle=3.14, slice_num=2, - ... hgap=1, L=3.5)) - >>> Typecast.multipole(sbend) - >>> sbend.KNL - '{1.57}' - >>> sbend.get('angle') - >>> sbend.get('hgap') #TODO: HGAP is just dropped ATM! - >>> sbend.type - 'multipole' - - >>> quad = Element(None, 'QUADRUPOLE', dicti(K1=3, slice_num=2, - ... L=2.5)) - >>> Typecast.multipole(quad) - >>> quad.KNL - '{0, 7.5}' - >>> quad.get('K1') - >>> quad.type - 'multipole' + # number of slices per element + exclusive(selector, 'density', 'slice') + if 'density' in selector: + density = selector['density'] + self._get_slice_num = lambda L: int(ceil(abs(L * density))) + else: + slice_num = selector.get('slice', 1) + self._get_slice_num = lambda L: slice_num - """ - if elem.type == 'sbend': - elem.KNL = '{%s}' % (elem.angle / elem.slice_num,) - del elem.angle - del elem.HGAP - elif elem.type == 'quadrupole': - elem.KNL = '{0, %s}' % (elem.K1 * elem.L,) - del elem.K1 - elif elem.type == 'solenoid': - elem.ksi = elem.KS / elem.slice_num - elem.lrad = elem.L / elem.slice_num - elem.L = 0 - return + # rescale elements + if selector.get('makethin', False): + self._rescale = rescale_makethin else: - return - - # set elem_class to multipole - elem.type = stri('multipole') - # replace L by LRAD property - elem.lrad = elem.pop('L', None) - - - -class Slice(object): - - @staticmethod - def simple(offset, refer, elem): - elems = [] - for slice_idx in range(elem.slice_num): - slice = Element(None, elem.type, copy(elem.args)) - if elem.name: - slice.name = "%s..%s" % (elem.name, slice_idx) - slice.at = offset + (slice_idx + refer)*elem.slice_len - elems.append(slice) - return None, elems - - @staticmethod - def optics(offset, refer, elem): - return elem, [ - Element("%s..%s" % (elem.name, slice_idx), - elem.name, - odicti(at=offset + (slice_idx + refer)*elem.slice_len)) - for slice_idx in range(elem.slice_num) ] - - @staticmethod - def optics_short(offset, refer, elem): - return elem, [ - Element( - None, - elem.name, - odicti(at=offset + (slice_idx + refer)*elem.slice_len)) - for slice_idx in range(elem.slice_num) ] - - @staticmethod - def loops(offset, refer, elem): - return elem, [ - Text('i = 0;'), - Text('while (i < %s) {' % elem.slice_num), - Element( - None, - elem.name, - odicti( - at=(offset + - (Identifier('i',True) + refer) * - elem.get('L', elem.get('lrad'))) - )), - Text('i = i + 1;'), - Text('}'), ] + self._rescale = rescale_thick + + # whether to use separate optics + if selector.get('use_optics', False): + def make_optic(elem, elem_len, slice_num): + optic = elem.copy() + optic.L = elem_len / slice_num + return [optic] + self._makeoptic = make_optic + self._stripelem = lambda elem: Element(None, elem.name, {}) + else: + self._makeoptic = lambda elem, slice_num: [] + self._stripelem = lambda elem: elem + + # slice distribution style over element length + style = selector.get('style', 'uniform') + if style == 'uniform': + self._distribution = self.uniform_slice_distribution + elif style == 'loop': + self._distribution = self.uniform_slice_loop + else: + raise ValueError("Unknown slicing style: {!r}".format(style)) + + def replace(self, elem, offset, refer): + elem_len = elem.get('L', 0) + slice_num = self._get_slice_num(elem_len) or 1 + optic = self._makeoptic(elem, slice_num) + elem = self._stripelem(elem) + elems = self._distribution(elem, offset, refer, elem_len, slice_num) + return optic, elems, elem_len + + def uniform_slice_distribution(self, elem, offset, refer, elem_len, slice_num): + slice_len = Decimal(elem_len) / slice_num + scaled = self._rescale(elem, 1/Decimal(slice_num)) + for slice_idx in range(slice_num): + slice = scaled.copy() + slice.at = offset + (slice_idx + refer)*slice_len + yield slice + + def uniform_slice_loop(self, elem, offset, refer, elem_len, slice_num): + slice_len = elem_len / slice_num + slice = self._rescale(elem, 1/Decimal(slice_num)).copy() + slice.at = offset + (Identifier('i', True) + refer) * slice_len + yield Text('i = 0;') + yield Text('while (i < %s) {' % slice_num) + yield slice + yield Text('i = i + 1;') + yield Text('}') + + +def rescale_thick(elem, ratio): + """Shrink/grow element size, while leaving the element type 'as is'.""" + if ratio == 1: + return elem + scaled = elem.copy() + scaled.L = elem.L * ratio + if scaled.type == 'sbend': + scaled.angle = scaled.angle * ratio + return scaled + + +def rescale_makethin(elem, ratio): + """ + Shrink/grow element size, while transforming elements to MULTIPOLEs. + + NOTE: rescale_makethin is currently not recommended! If you use it, + you have to make sure, your slice length will be sufficiently small! + """ + if elem.type not in ('sbend', 'quadrupole', 'solenoid'): + return elem + elem = elem.copy() + if elem.type == 'sbend': + elem.KNL = [elem.angle * ratio] + del elem.angle + del elem.HGAP + elif elem.type == 'quadrupole': + elem.KNL = [0, elem.K1 * elem.L] + del elem.K1 + elif elem.type == 'solenoid': + elem.ksi = elem.KS * ratio + elem.lrad = elem.L * ratio + elem.L = 0 + return + # set elem_class to multipole + elem.type = stri('multipole') + # replace L by LRAD property + elem.lrad = elem.pop('L', None) + return elem #---------------------------------------- -# json output +# JSON/YAML serialization #---------------------------------------- -def json_adjust_element(elem): +def _adjust_element(elem): if not elem.type: return () return odicti([('name', elem.name), @@ -577,9 +472,47 @@ def json_adjust_element(elem): [(k,v) for k,v in elem.args.items() if v is not None]), -#---------------------------------------- -# YAML serialization utility: -#---------------------------------------- +def _getstate(output_data): + return odicti( + (seq.name, odicti( + list(seq.seq[0].args.items()) + + [('elements', list(chain.from_iterable(map(_adjust_element, + seq.seq[1:-1])))),] + )) + for seq in output_data + if isinstance(seq, Sequence)) + + +class Json(object): + + def __init__(self): + import json + self.json = json + + def dump(self, data, stream): + + class fakefloat(float): + """Used to serialize Decimal. + See: http://stackoverflow.com/a/8274307/650222""" + def __init__(self, value): + self._value = value + def __repr__(self): + return str(self._value) + + class ValueEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Value): + return obj.value + if isinstance(obj, Decimal): + return fakefloat(obj.normalize()) + # Let the base class default method raise the TypeError + return json.JSONEncoder.default(self, obj) + + self.json.dump(data, stream, + indent=2, + separators=(',', ' : '), + cls=ValueEncoder) + class Yaml(object): @@ -589,7 +522,7 @@ def __init__(self): self.yaml = yaml self.dict = pydicti.odicti - def dump(self, data, stream=None, **kwds): + def dump(self, data, stream=None): yaml = self.yaml class Dumper(yaml.SafeDumper): pass @@ -609,12 +542,12 @@ def _Decimal_representer(dumper, data): Dumper.add_representer(Symbolic, _Value_representer) Dumper.add_representer(Identifier, _Value_representer) Dumper.add_representer(Composed, _Value_representer) - Dumper.add_representer(decimal.Decimal, _Decimal_representer) - return yaml.dump(data, stream, Dumper, - default_flow_style=False, **kwds) + Dumper.add_representer(Decimal, _Decimal_representer) + return yaml.dump(data, stream, Dumper, default_flow_style=False) - def load(self, stream, **kwds): - class OrderedLoader(Loader): + def load(self, stream): + yaml = self.yaml + class OrderedLoader(yaml.SafeLoader): pass OrderedLoader.add_constructor( yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, @@ -625,52 +558,50 @@ class OrderedLoader(Loader): # main #---------------------------------------- -def transform(elem, - typecast=Typecast.preserve, - slicing=None): +def transform(elem, slicing): + """Transform sequence.""" + if not isinstance(elem, Sequence): - return (elem,) + return elem seq = elem first = seq.seq[0] last = seq.seq[-1] - offsets = dicti(entry=0, centre=decimal.Decimal(0.5), exit=1) + offsets = dicti(entry=0, centre=Decimal(0.5), exit=1) refer = offsets[str(first.get('refer', 'centre'))] - # TODO: when to slice: explicit/always/never/{select classes} + # create slicer + transforms = [Transform(s) for s in slicing] + [] + transforms.append(Transform({})) + def transform(elem, offset): + for t in transforms: + if t.match(elem): + return t.replace(elem, offset, refer) - # output method - slice_method = getattr(Slice, first.pop('method', 'simple').lower()) + templates = [] # predefined element templates + elements = [] # actual elements to put in sequence + position = 0 # current element position - # iterate through sequence - elems = [] - optics = [] - length = 0 for elem in seq.seq[1:-1]: if elem.type: - elem_len = elem.get('L', 0) - if detect_slicing(elem, slicing): - typecast(elem) - optic, elem = slice_method(length, refer, elem) - if optic: - optics.append(optic) - elems += elem - else: - elems += filter_default(length, refer, elem) - length += elem_len + optic, elem, elem_len = transform(elem, position) + templates += optic + elements += elem + position += elem_len else: - elems.append(elem) - first.L = length + elements.append(elem) + first.L = position + + if templates: + templates.insert(0, Text('! Template elements for %s:' % first.get('name'))) + templates.append(Text()) - if optics: - optics.insert(0, Text('! Optics definition for %s:' % first.get('name'))) - optics.append(Text()) + return Sequence([first] + elements + [last], templates, first.name) - return Sequence([first] + elems + [last], optics, first.get('name')) -class File(list): +class Document(list): @classmethod def parse(cls, lines): @@ -680,21 +611,11 @@ def parse(cls, lines): @classmethod def parse_line(cls, line): """ - Parse a single-line MAD-X statement. + Parse a single-line MAD-X input statement. Return an iterable that iterates over parsed elements. - TODO: multi-line commands! - - >>> list(File.parse_line(' \t ')) - [''] - - >>> list(File.parse_line(' \t ! a comment; ! ')) - ['! a comment; ! '] - - >>> list(File.parse_line(' use, hello=world, z=23.23e2; k: z; !')) - ['!', Element(None, 'use', hello=world, z=2323.0), Element('k', 'z')] - + TODO: Does not support multi-line commands yet! """ code, comment = regex.comment_split.match(line).groups() if comment: @@ -711,21 +632,13 @@ def parse_line(cls, line): if len(commands) == 1 and not comment: yield Text('') - def __str__(self, code): - """Format sequence to MadX format.""" - pass def main(argv=None): + # parse command line options from docopt import docopt args = docopt(__doc__, argv, version=__version__) - # prepare filters - typecast = Typecast.multipole if args['--makethin'] else Typecast.preserve - transformation = partial(transform, - typecast=typecast, - slicing=args['--slice']) - # perform input if args[''] and args[''] != '-': with open(args[''], 'rt') as f: @@ -739,44 +652,36 @@ def main(argv=None): else: from sys import stdout as output_file + # get slicing definition + if args['--slice']: + with open(args['--slice']) as f: + transforms_doc = Yaml().load(f) + else: + transforms_doc = [] + def load(file): - return Sequence.detect(File.parse(file)) + return Sequence.detect(Document.parse(file)) - def transform(input_data): - return (transformation(el) + def edit(input_data): + return (transform(el, slicing=transforms_doc) for el in input_data) - def serialize(output_data): - return odicti( - (seq.name, - list(chain.from_iterable(map(json_adjust_element, seq.seq)))) - for seq in output_data - if isinstance(seq, Sequence)) - if args['--json']: - import json + json = Json() def dump(output_data): - json.dump( - serialize(output_data), - output_file, - indent=3, - separators=(',', ' : '), - cls=ValueEncoder) + json.dump(_getstate(output_data), output_file) elif args['--yaml']: yaml = Yaml() def dump(output_data): - yaml.dump( - serialize(output_data), - output_file) + yaml.dump(_getstate(output_data), output_file) else: def dump(output_data): output_file.write("\n".join(map(str, output_data))) # one line to do it all: - dump(transform(load(input_file))) - + dump(edit(load(input_file))) main.__doc__ = __doc__ if __name__ == '__main__': diff --git a/setup.py b/setup.py index 6d56f0e..77e3e84 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,8 @@ 'docopt' ], extras_require={ - 'yaml': ['pyyaml'] + 'yaml': ['PyYAML'], + 'slice': ['PyYAML'] }, test_suite='nose.collector', tests_require=['nose'], diff --git a/test.py b/test.py new file mode 100644 index 0000000..ec6b6e6 --- /dev/null +++ b/test.py @@ -0,0 +1,197 @@ +# test utilities +import unittest + +from decimal import Decimal + +from pydicti import dicti + +# tested module +import madseq + + +class TestUtils(unittest.TestCase): + + def test_none_checked(self): + tostr = madseq.none_checked(str) + self.assertIs(tostr(None), None) + self.assertEqual(tostr(1), '1') + + def test_stri(self): + stri = madseq.stri + self.assertEqual(stri("HeLLo"), "helLO") + self.assertEqual("HeLLo", stri("helLO")) + self.assertNotEqual(stri("HeLLo"), "WOrld") + self.assertNotEqual("HeLLo", stri("WOrld")) + s = "HEllO wORld" + self.assertEqual('%s' % (stri(s),), s) + + def test_Re(self): + Re = madseq.Re + r1 = Re('hello') + r2 = Re(r1, 'world') + self.assertTrue(r1.search(' helloworld ')) + self.assertFalse(r1.search('yelloworld')) + self.assertTrue(r2.match('helloworld anything')) + self.assertFalse(r2.match(' helloworld anything')) + +class test_Parse(unittest.TestCase): + + def test_Number(self): + parse = madseq.parse_number + self.assertEqual(parse('-13'), -13) + self.assertEqual(parse('1.2e1'), Decimal('12.0')) + + def test_Document_parse_line(self): + + parse = madseq.Document.parse_line + Element = madseq.Element + + self.assertEqual(list(parse(' \t ')), + ['']) + + self.assertEqual(list(parse(' \t ! a comment; ! ')), + ['! a comment; ! ']) + + self.assertEqual(list(parse(' use, z=23.23e2; k: z; !')), + ['!', + Element(None, 'use', {'z': Decimal('23.23e2')}), + Element('k', 'z', {})]) + + def test_Symbolic(self): + + i = madseq.parse_number('-13') + f = madseq.parse_number('12.') + s = madseq.Symbolic.parse('pi') + + self.assertEqual(str(f + s), "12 + pi") + self.assertEqual(str(s - f), "pi - 12") + self.assertEqual(str(i + f * s), "-13 + (12 * pi)") + self.assertEqual(str(s / s), "pi / pi") + + + +class test_regex(unittest.TestCase): + + def setUp(self): + name = self._testMethodName.split('_', 1)[1] + reg = str(getattr(madseq.regex, name)).lstrip('^').rstrip('$') + self.r = madseq.Re('^', reg , '$') + + def test_number(self): + self.assertTrue(self.r.match('23')) + self.assertTrue(self.r.match('23.0')) + self.assertTrue(self.r.match('-1e+1')) + self.assertTrue(self.r.match('+2e-3')) + self.assertFalse(self.r.match('')) + self.assertFalse(self.r.match('e.')) + self.assertFalse(self.r.match('.e')) + + def test_thingy(self): + self.assertTrue(self.r.match('unseparated')) + self.assertTrue(self.r.match('23')) + self.assertTrue(self.r.match('23.0')) + self.assertTrue(self.r.match('-1e+1')) + self.assertTrue(self.r.match('+2e-3')) + self.assertFalse(self.r.match('e;')) + self.assertFalse(self.r.match('e,')) + self.assertFalse(self.r.match(' e')) + self.assertFalse(self.r.match('e!')) + self.assertTrue(self.r.match('"a.1"')) + + def test_identifier(self): + self.assertTrue(self.r.match('a')) + self.assertTrue(self.r.match('a1')) + self.assertTrue(self.r.match('a.1')) + self.assertFalse(self.r.match('')) + self.assertFalse(self.r.match('1a')) + + def test_string(self): + self.assertTrue(self.r.match('"hello world"')) + self.assertTrue(self.r.match('"hello !,; world"')) + self.assertFalse(self.r.match('"foo" bar"')) + self.assertFalse(self.r.match('foo" bar"')) + self.assertFalse(self.r.match('')) + + def test_param(self): + self.assertTrue(self.r.match('unseparated')) + self.assertTrue(self.r.match('23')) + self.assertTrue(self.r.match('23.0')) + self.assertTrue(self.r.match('-1e+1')) + self.assertTrue(self.r.match('+2e-3')) + self.assertTrue(self.r.match('"hello world"')) + self.assertFalse(self.r.match('"foo" bar"')) + self.assertFalse(self.r.match('foo" bar"')) + self.assertFalse(self.r.match('')) + self.assertFalse(self.r.match('e;')) + self.assertFalse(self.r.match('e,')) + self.assertFalse(self.r.match(' e')) + self.assertFalse(self.r.match('e!')) + + def test_cmd(self): + pass + + def test_arg(self): + pass + def test_comment_split(self): + pass + def test_is_string(self): + pass + def test_is_identifier(self): + pass + + + +class TestElement(unittest.TestCase): + + def test_parse_format_identity(self): + mad = "name: type, a=97, b=98, c=99, d=100, e=101;" + el = madseq.Element.parse(mad) + self.assertEqual(str(mad), mad) + self.assertEqual(el.c, 99) + self.assertEqual(el.E, 101) + + + +class TestSequence(unittest.TestCase): + pass + + +class TestMakethin(unittest.TestCase): + + def test_rescale_makethin_sbend(self): + sbend = madseq.Element(None, + 'SBEND', dicti(angle=3.14, hgap=1, L=3.5)) + scaled = madseq.rescale_makethin(sbend, 0.5) + self.assertEqual(scaled.KNL, [1.57]) + self.assertEqual(scaled.get('angle'), None) + self.assertEqual(scaled.get('hgap'), None) + self.assertEqual(scaled.type, 'multipole') + + def test_rescale_makethin_quadrupole(self): + quad = madseq.Element(None, 'QUADRUPOLE', dicti(K1=3, L=2.5)) + scaled = madseq.rescale_makethin(quad, 0.5) + self.assertEqual(scaled.KNL, [0, 7.5]) + self.assertEqual(scaled.get('K1'), None) + self.assertEqual(scaled.type, 'multipole') + + def test_rescale_thick(self): + pi = 3.14 + el = madseq.Element(None, 'SBEND', {'angle': pi, 'L': 1}) + scaled = madseq.rescale_thick(el, 0.5) + self.assertEqual(scaled.angle, pi/2) + self.assertEqual(scaled.type, 'SBEND') + + +class TestSlice(unittest.TestCase): + pass + +class TestFile(unittest.TestCase): + pass + +class TestJson(unittest.TestCase): + pass + +# execute tests if this file is invoked directly: +if __name__ == '__main__': + unittest.main() + From d68d7d735f3d6dc21decefee64222a459d94aa24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Fri, 2 May 2014 01:12:29 +0200 Subject: [PATCH 7/8] Parse arrays into new data type --- madseq.py | 27 +++++++++++++++++++++++++-- setup.py | 2 +- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/madseq.py b/madseq.py index 9c058b4..c8d7c42 100755 --- a/madseq.py +++ b/madseq.py @@ -140,7 +140,7 @@ def __str__(self): return str(self.value) def fmtArg(self): - return self.assign + str(self.value) + return self.assign + str(self) @classmethod def parse(cls, text, assign='='): @@ -150,7 +150,10 @@ def parse(cls, text, assign='='): try: return parse_string(text) except ValueError: - return Symbolic.parse(text, assign) + try: + return Array.parse(text, assign) + except ValueError: + return Symbolic.parse(text, assign) def parse_number(text): @@ -173,6 +176,26 @@ def parse_string(text): raise ValueError("Invalid string: %s" % (text,)) +class Array(Value): + + @classmethod + def parse(cls, text, assign=False): + """Parse a MAD-X array.""" + if text[0] != '{': + raise ValueError("Invalid array: %s" % (text,)) + if text[-1] != '}': + raise Exception("Array not terminated correctly: %s" % (text,)) + try: + return cls([Value.parse(field.strip(), assign) + for field in text[1:-1].split(',')], + assign) + except ValueError: + raise Exception("Array not well-formed: %s" % (text,)) + + def __str__(self): + return '{' + ','.join(map(str, self.value)) + '}' + + class Symbolic(Value): """Base class for identifiers and composed arithmetic expressions.""" diff --git a/setup.py b/setup.py index 77e3e84..74293de 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ setup( name='madseq', - version='0.2', + version='0.3', description='Parser/transformator for MAD-X sequences', long_description=long_description, author='Thomas Gläßle', From 34b53c71ce97a0df8dfb2d0838e8704fe5dd4ce4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gl=C3=A4=C3=9Fle?= Date: Fri, 2 May 2014 01:36:54 +0200 Subject: [PATCH 8/8] Fix tests --- .travis.yml | 1 - test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 0161810..6c37e48 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,6 @@ language: python python: - "2.6" - "2.7" - - "3.2" - "3.3" - "pypy" install: diff --git a/test.py b/test.py index ec6b6e6..6be13d4 100644 --- a/test.py +++ b/test.py @@ -13,7 +13,7 @@ class TestUtils(unittest.TestCase): def test_none_checked(self): tostr = madseq.none_checked(str) - self.assertIs(tostr(None), None) + self.assertEqual(tostr(None), None) self.assertEqual(tostr(1), '1') def test_stri(self):