Warping a uniformly-sampled disk to a uniformly-sampled sphere.

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

import plotly.graph_objects as go

In [2]:
num_points = 10000

seed = 12357


In [3]:
key = jax.random.PRNGKey(seed)

In [4]:
def in_circle(xy):
    return jnp.sum(xy**2, axis=1) <= 1

xy = jax.random.uniform(key, shape=(num_points, 2), minval=-1, maxval=1)
xy_c = xy[in_circle(xy)]

len(xy_c) / len(xy), np.pi/4

(0.7861, 0.7853981633974483)

In [5]:
# add z to make a cosine-weighted hemisphere (i.e. more points are near the pole)
z = jnp.sqrt(1 - jnp.sum(xy_c**2, axis=1))

In [6]:
# plot 3d points in plotly

fig = go.Figure(data=[go.Scatter3d(
    x=xy_c[:,0],
    y=xy_c[:,1],
    z=z,
    mode='markers',
    marker=dict(
        size=2,
        # color=z,                # set color to an array/list of desired values
        # colorscale='Viridis',   # choose a colorscale
        opacity=0.5
    )
)])

# tight layout
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()

In [7]:
z_s = -1 + 2 * z

# normalize xy such that resulting point is on the unit sphere
m = jnp.sqrt((1 - z_s**2) / jnp.sum(xy_c**2, axis=1))
xy_s = m[:, None] * xy_c

p = jnp.concatenate([xy_s, z_s[:,None]], axis=1)

In [17]:
## note that this cell is not the right way to do this!

n = jnp.array([0, 0, 1])

p1 = jnp.concatenate([xy_c, z[:, None]], axis=1)
p2 = (p + n[None, :]) / jnp.linalg.norm(p + n[None, :], axis=1)[:, None]

jnp.linalg.norm(p1 - p2, axis=1)

Array([0.07576162, 0.1783683 , 0.20610476, ..., 0.25828815, 0.26718742,
       0.27472296], dtype=float32)

In [8]:
jnp.min(z)

Array(0.01134926, dtype=float32)

In [9]:
jnp.linalg.norm(p, axis=1)

Array([1.        , 0.99999994, 1.        , ..., 0.99999994, 0.99999994,
       0.99999994], dtype=float32)

In [11]:
fig = go.Figure(data=[go.Scatter3d(
    x=p[:,0],
    y=p[:,1],
    z=p[:,2],
    mode='markers',
    marker=dict(
        size=1,
        # color=z,                # set color to an array/list of desired values
        # colorscale='Viridis',   # choose a colorscale
        opacity=0.5
    )
)])

# tight layout
# fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()

In [18]:
r_c = np.linalg.norm(xy_c, axis=1)

In [27]:
phi = jnp.arccos(z)

Array(False, dtype=bool)

In [28]:
z_s = jnp.cos(2 * phi)

In [29]:
r_s = jnp.sqrt(1 - z_s**2)

In [30]:
xy_s = r_s[:, None] * xy_c / r_c[:, None]

In [31]:
fig = go.Figure(data=[go.Scatter3d(
    x=xy_s[:,0],
    y=xy_s[:,1],
    z=z_s,
    mode='markers',
    marker=dict(
        size=1,
        # color=z,                # set color to an array/list of desired values
        # colorscale='Viridis',   # choose a colorscale
        opacity=0.5
    )
)])

# tight layout
# fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()