diff --git a/gammapy/maps/geom.py b/gammapy/maps/geom.py index 2ea5978345..5a42526551 100644 --- a/gammapy/maps/geom.py +++ b/gammapy/maps/geom.py @@ -45,34 +45,40 @@ def make_axes_cols(axes, axis_names=None): Parameters ---------- - axes : list of `~MapAxis` - - axis_names : list of str + axes : list + Python list of `MapAxis` objects + Returns + ------- + cols : list + Python list of `~astropy.io.fits.Column` """ - colname = { - 'energy': ['ENERGY', 'E_MIN', 'E_MAX'], - 'time': ['TIME', 'T_MIN', 'T_MAX'], - } + size = np.prod([ax.nbin for ax in axes]) + chan = np.arange(0, size) + cols = [fits.Column('CHANNEL', 'I', array=chan), ] if axis_names is None: axis_names = [ax.name for ax in axes] + axis_names = [_.upper() for _ in axis_names] - size = np.prod([ax.nbin for ax in axes]) - chan = np.arange(0, size) - cols = [fits.Column('CHANNEL', 'I', array=chan), ] axes_ctr = np.meshgrid(*[ax.center for ax in axes]) axes_min = np.meshgrid(*[ax.edges[:-1] for ax in axes]) axes_max = np.meshgrid(*[ax.edges[1:] for ax in axes]) + for i, (ax, name) in enumerate(zip(axes, axis_names)): - names = colname.get(name.lower(), - ['AXIS%i' % i, - 'AXIS%i_MIN' % i, 'AXIS%i_MAX' % i]) - for t, v in zip(names, [axes_ctr, axes_min, axes_max]): + if name == 'ENERGY': + colnames = ['ENERGY', 'E_MIN', 'E_MAX'] + elif name == 'TIME': + colnames = ['TIME', 'T_MIN', 'T_MAX'] + else: + s = 'AXIS%i' % i if name == '' else name + colnames = [s, s + '_MIN', s + '_MAX'] + + for colname, v in zip(colnames, [axes_ctr, axes_min, axes_max]): array = np.ravel(v[i]) unit = ax.unit.to_string() - cols.append(fits.Column(t, 'E', array=array, unit=unit)) + cols.append(fits.Column(colname, 'E', array=array, unit=unit)) return cols diff --git a/gammapy/maps/tests/test_wcsnd.py b/gammapy/maps/tests/test_wcsnd.py index 19727e88bb..7314baca07 100644 --- a/gammapy/maps/tests/test_wcsnd.py +++ b/gammapy/maps/tests/test_wcsnd.py @@ -17,7 +17,7 @@ pytest.importorskip('scipy') pytest.importorskip('reproject') -axes1 = [MapAxis(np.logspace(0., 3., 3), interp='log')] +axes1 = [MapAxis(np.logspace(0., 3., 3), interp='log', name='spam')] axes2 = [MapAxis(np.logspace(0., 3., 3), interp='log'), MapAxis(np.logspace(1., 3., 4), interp='lin')] skydir = SkyCoord(110., 75.0, unit='deg', frame='icrs')