In [1]:
import numpy as np
import astropy.units as u
from gammapy.maps import Map, WcsGeom, MapAxis, WcsNDMap
from gammapy.maps.core import USE_JAX

# `NP` is the array namespace matching the backend selected by gammapy.maps
# (`USE_JAX`), so the cells below build arrays of the same kind the maps use.
# jax is only imported when that backend is active, so the notebook can still
# start in an environment where JAX is not installed.
if USE_JAX:
    import jax.numpy as jnp
    NP = jnp
else:
    NP = np

In [2]:
# Non-spatial axis with 5 edges -> 4 bins, shared by both geometries
test_edges = [1, 2, 3, 4, 5]
axis_1 = MapAxis.from_edges(test_edges, name="test")
# Small 10x10-pixel CAR geometry centred on (0, 0)
geom = WcsGeom.create(proj='CAR', skydir=(0, 0), npix=10, axes=[axis_1])
# All-sky geometry in the AIT projection
geom_allsky = WcsGeom.create(axes=[axis_1], proj='AIT')


### Map creation

In [3]:
# Constant-valued maps on the small CAR geometry, each carrying a unit
m1 = Map.from_geom(geom, unit="s", data=3600)
m2 = Map.from_geom(geom, unit="m2", data=1)
m3 = Map.from_geom(geom, unit="cm2", data=1)
# All-sky map created without an explicit data argument
m_allsky = Map.from_geom(geom_allsky)

In [4]:
# Which array type backs each map's data (JAX or NumPy)?
print(type(m1.data))
print(type(m2.data))
print(type(m_allsky.data))

<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>


### Map operations

In [5]:
# Map product: data are multiplied and the units combine (s * m2)
expo = m1 * m2
print(expo)
print(type(expo.data))

WcsNDMap

	geom  : WcsGeom 
 	axes  : ['lon', 'lat', 'test']
	shape : (10, 10, 4)
	ndim  : 3
	unit  : m2 s
	dtype : int64

<class 'jaxlib.xla_extension.ArrayImpl'>


In [6]:
# Map sum with compatible units: the cm2 map is expressed in m2 before adding
m4 = m2 + m3
print(m4.quantity[0, 0])
print(type(m4.data))

[1.0001 1.0001 1.0001 1.0001 1.0001 1.0001 1.0001 1.0001 1.0001 1.0001] m2
<class 'jaxlib.xla_extension.ArrayImpl'>


In [7]:
# Comparing a map with a Quantity (1 day) yields a mask-like map
mask = m1 > 1 * u.d
print(type(mask.data))

<class 'jaxlib.xla_extension.ArrayImpl'>


### Other operations

In [8]:
# Cutout: 2 deg region around the centre of the geometry
mcut = m1.cutout(geom.center_skydir, 2 * u.deg)
print(type(mcut.data))

<class 'jaxlib.xla_extension.ArrayImpl'>


In [9]:
# Resample along the non-spatial axis using a 2x downsampled copy of it
m_resample = m1.resample_axis(axis_1.downsample(2))
print(type(m_resample.data))

<class 'jaxlib.xla_extension.ArrayImpl'>


In [10]:
# Stack: stack() modifies the map it is called on in place, so stack into a
# copy of m1. Re-running this cell then always gives the same result instead
# of accumulating another 3600 s on m1 each time.
m_stacked = m1.copy()
m_stacked.stack(mcut)
# A pixel covered by the cutout should now hold 3600 s + 3600 s
m_stacked.data[0, 5, 5] == 7200

Array(True, dtype=bool)

### Accessing/setting elements

#### By pixel coordinates

In [11]:
# Pixel coordinates of every pixel of the geometry, then read the map there
pix = m1.geom.get_pix()
values_by_pix = m1.get_by_pix(pix)
print(type(values_by_pix))

<class 'jaxlib.xla_extension.ArrayImpl'>


In [12]:
# Write float64 ones at every pixel position (m1 itself holds int64 data)
vals = NP.ones(pix[0].shape, dtype=NP.float64)
m1.set_by_pix(pix, vals)
print(type(m1.data))

<class 'jaxlib.xla_extension.ArrayImpl'>




#### By world coordinates

In [13]:
# World coordinates of every pixel, then read the map back at those coordinates.
# get_by_coord's second positional argument is the fill value for coordinates
# outside the map, so no coordinate array is passed there.
coords = m1.geom.get_coord()
print(m1.get_by_coord(coords).__class__)

<class 'jaxlib.xla_extension.ArrayImpl'>


In [14]:
# Write twice the previous values (ones from the set_by_pix cell) through world
# coordinates. Build a new array rather than `vals *= 2`, so re-running this
# cell does not keep doubling the values.
vals_doubled = 2 * vals
m1.set_by_coord(coords, vals_doubled)
m1.data[0, 0]

Array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)

#### With AIT projection?

In [15]:
# Read the all-sky AIT map back at its own pixel coordinates. The corner
# pixel [0, 0, 0] comes back as NaN (presumably it lies outside the AIT footprint).
coords = m_allsky.geom.get_coord()
vals = m_allsky.get_by_coord(coords)
print(type(vals))
NP.isnan(vals[0, 0, 0])

<class 'jaxlib.xla_extension.ArrayImpl'>


Array(True, dtype=bool)

In [16]:
# KNOWN ISSUE: set_by_coord on the AIT all-sky map does not work yet.
# The call is wrapped so that "Restart & Run All" does not stop here; if it
# raises, the error is printed instead of aborting the notebook.
vals = NP.ones_like(coords['lat'], dtype=NP.float32) * 2
try:
    m_allsky.set_by_coord(coords, vals)
except Exception as exc:
    print(f"set_by_coord failed on the all-sky map: {type(exc).__name__}: {exc}")

### Convolution?

In [17]:
# Fresh all-zero map for the convolution test (note: this rebinds m1)
m1 = Map.from_geom(geom, unit="s", data=0)

In [18]:
# Single point-like value of 1 at (lon, lat) = (0, 0) deg, the geometry centre
center_coord = {'lon': 0. * u.deg, 'lat': 0. * u.deg, 'test': 2}
m1.set_by_coord(center_coord, 1)

In [20]:
# Unlike the operations above, smooth() hands back a plain numpy.ndarray,
# not a JAX array
m_smoothed = m1.smooth("0.3 deg")
type(m_smoothed.data)

numpy.ndarray