In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
from numba import jit

import astropy.units as u
from astropy import wcs
from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy.visualization import make_lupton_rgb
from astropy.utils.data import download_file, clear_download_cache

import matplotlib.pyplot as plt
from matplotlib import rcParams

print(np.__version__)

1.19.5


In [2]:
@jit(nopython=True, nogil=True)
def cart2polar(x, y):
    """
    Convert Cartesian coordinates to polar coordinates.

    Parameters
    ----------
    x, y : floats or arrays
        Cartesian coordinates

    Returns
    -------
    r, theta : floats or arrays
        Radius ``hypot(x, y)`` and angle ``arctan2(y, x)`` in radians.
        The angle is zero along the +x axis and increases toward +y.
    """
    # arctan2(y, x): theta is referenced to the +x axis, not to the vertical.
    return np.hypot(x, y), np.arctan2(y, x)

@jit(nopython=True, nogil=True)
def polar2cart(r, theta):
    """
    Transform polar coordinates to Cartesian

    This is the inverse of ``cart2polar``: ``theta`` is interpreted as the
    angle produced by ``np.arctan2(y, x)``, i.e. zero along the +x axis and
    increasing toward +y.

    Parameters
    ----------
    r, theta : floats or arrays
        Polar coordinates (theta in radians)

    Returns
    -------
    x, y : floats or arrays
        Cartesian coordinates
    """
    # theta is referenced to the +x axis, so cos gives x and sin gives y.
    # (The previous cos -> y / sin -> x mapping swapped the axes relative to
    # cart2polar, which was only invisible for inputs with x == y.)
    x = r * np.cos(theta)
    y = r * np.sin(theta)
    return x, y

def index_coords(data):
    """
    Build x & y index-coordinate grids for a numpy array.

    Parameters
    ----------
    data : numpy array
        Array with at least two dimensions. Only ``data.shape[:2]`` is
        used, read as ``(ny, nx)``.

    Returns
    -------
    x, y : arrays
        Float arrays of shape ``(ny, nx)``. ``x`` holds the column index
        and ``y`` the row index of every element. No origin shift is
        applied: element ``[0, 0]`` has coordinates ``(0.0, 0.0)``.
    """
    n_rows, n_cols = data.shape[:2]

    col_axis = np.arange(float(n_cols))
    row_axis = np.arange(float(n_rows))

    # Default "xy" indexing: first output varies along columns, second along rows.
    grid_x, grid_y = np.meshgrid(col_axis, row_axis)
    return grid_x, grid_y

@jit(nopython=True, nogil=True)
def index_to_galcen(x_idx, y_idx, x_cen, y_cen):
    """
    Shift index coordinates so they are measured relative to a centre.

    Parameters
    ----------
    x_idx, y_idx : floats or arrays
        Index coordinates.
    x_cen, y_cen : floats
        Centre position that is subtracted from the indices
        (presumably the galaxy centre, going by the function name).

    Returns
    -------
    x, y : floats or arrays
        ``x_idx - x_cen`` and ``y_idx - y_cen``.
    """
    dx = x_idx - x_cen
    dy = y_idx - y_cen
    return dx, dy

In [3]:
# Sample Cartesian inputs: two separate but identical float ramps 0.0 .. 99.0.
x = np.arange(100) * 1.0
y = np.arange(100) * 1.0

# Zero-filled 1000 x 1000 array used as the input image for index_coords.
data = np.zeros((1000, 1000))

# Cartesian -> polar -> Cartesian on the sample points.
r, theta = cart2polar(x, y)
xx, yy = polar2cart(r, theta)

# Index-coordinate grids matching the shape of `data`.
xid, yid = index_coords(data)

In [4]:
%timeit r, thetha = cart2polar(x, y)

2.17 µs ± 22.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [18]:
# Time the polar -> Cartesian conversion of r, theta.
%timeit xx, yy = polar2cart(r, theta)

3.89 µs ± 17.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [55]:
# Time building the index-coordinate grids for `data`.
%timeit xid, yid = index_coords(data)

1.6 ms ± 38.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [69]:
# Time shifting the index grids to a centre at (500.0, 500.0).
%timeit x0, y0 = index_to_galcen(xid, yid, 500.0, 500.0)

2.21 ms ± 26.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
# Plain-NumPy variant (no numba decorator). This redefinition shadows the
# jitted cart2polar defined earlier in the notebook.
def cart2polar(x, y):
    """
    Convert Cartesian coordinates to polar coordinates.

    Parameters
    ----------
    x, y : floats or arrays
        Cartesian coordinates

    Returns
    -------
    r, theta : floats or arrays
        Radius ``hypot(x, y)`` and angle ``arctan2(y, x)`` in radians,
        measured from the +x axis toward +y.
    """
    radius = np.hypot(x, y)
    angle = np.arctan2(y, x)
    return radius, angle

In [25]:
# Time the cart2polar definition currently in scope on x, y.
%timeit cart2polar(x, y)

2 µs ± 17.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [5]:
# NOTE: absolute, machine-specific path -- adjust to wherever M51.fits lives.
# fits.getdata opens the file, reads the data of the requested HDU (ext=0,
# the primary HDU) and closes the file again, instead of leaving the HDUList
# returned by fits.open() unclosed.
img = fits.getdata(
    "/Users/song/Dropbox/work/project/huoguo/huoguo/data/M51.fits", ext=0)

In [6]:
# Two separate, identical float64 ramps 0.0 .. 99.0 used as sample coordinates.
x = np.arange(100, dtype=np.float64)
y = np.arange(100, dtype=np.float64)

In [7]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [28]:
# NOTE: `jit` here is jax.jit (imported from jax above), which shadows the
# numba `jit` imported in the first cell.
@jit
def img_polar_coord(img):
    """
    Polar coordinates of every pixel of a 2-D image, relative to its middle.

    Parameters
    ----------
    img : 2-D array
        Only its shape ``(ny, nx)`` is used.

    Returns
    -------
    r, theta : arrays of shape (ny, nx)
        Radius and angle (``arctan2(y, x)``, radians) of each pixel index
        after subtracting ``nx / 2`` from the column index and ``ny / 2``
        from the row index.
    """
    n_rows, n_cols = img.shape
    col_idx, row_idx = jnp.meshgrid(jnp.arange(n_cols), jnp.arange(n_rows))

    # Offsets from the (nx / 2, ny / 2) reference point.
    dx = col_idx - (n_cols / 2.)
    dy = row_idx - (n_rows / 2.)

    radius = jnp.hypot(dx, dy)
    angle = jnp.arctan2(dy, dx)
    return radius, angle

In [30]:
# Time img_polar_coord on img; the float32 cast via astype is part of the timed statement.
%timeit r, theta = img_polar_coord(img.astype(np.float32))

2.03 ms ± 7.85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [32]:
# Compute the radius and angle maps for img (cast to float32 first).
r, theta = img_polar_coord(img.astype(np.float32))

In [33]:
# Display the radius map returned by img_polar_coord.
r

DeviceArray([[362.03867, 361.33224, 360.62723, ..., 359.9236 , 360.62723,
              361.33224],
             [361.33224, 360.62445, 359.91806, ..., 359.213  , 359.91806,
              360.62445],
             [360.62723, 359.91806, 359.21024, ..., 358.5038 , 359.21024,
              359.91806],
             ...,
             [359.9236 , 359.213  , 358.5038 , ..., 357.79602, 358.5038 ,
              359.213  ],
             [360.62723, 359.91806, 359.21024, ..., 358.5038 , 359.21024,
              359.91806],
             [361.33224, 360.62445, 359.91806, ..., 359.213  , 359.91806,
              360.62445]], dtype=float32)

In [34]:
# Display the angle map (radians) returned by img_polar_coord.
theta

DeviceArray([[-2.3561945 , -2.3542376 , -2.352273  , ..., -0.791292  ,
              -0.7893197 , -0.7873551 ],
             [-2.3581514 , -2.3561945 , -2.35423   , ..., -0.78933513,
              -0.7873628 , -0.7853982 ],
             [-2.360116  , -2.358159  , -2.3561945 , ..., -0.78737056,
              -0.7853982 , -0.78343356],
             ...,
             [ 2.3620884 ,  2.3601315 ,  2.358167  , ...,  0.7853982 ,
               0.7834258 ,  0.7814612 ],
             [ 2.360116  ,  2.358159  ,  2.3561945 , ...,  0.78737056,
               0.7853982 ,  0.78343356],
             [ 2.3581514 ,  2.3561945 ,  2.35423   , ...,  0.78933513,
               0.7873628 ,  0.7853982 ]], dtype=float32)

In [57]:
# One-off timing of cart2polar on the index grids shifted by 256 along both axes.
%time r, theta = cart2polar(xid - 256., yid - 256.)

CPU times: user 7.92 ms, sys: 1.88 ms, total: 9.8 ms
Wall time: 8.66 ms
