# Morton Curves

In [None]:
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

## 2D

In [None]:
def morton(x: int, y: int) -> int:
    """Interleave the bits of ``x`` and ``y`` into a 2D Morton (Z-order) index.

    Bits of ``x`` occupy the odd bit positions of the result and bits of ``y``
    the even positions. Only the low 16 bits of each coordinate are used.
    Also accepts NumPy integer arrays, in which case it is applied elementwise.
    """
    odd_bits = _part1by1(x) << 1  # x contributes positions 1, 3, 5, ...
    even_bits = _part1by1(y)  # y contributes positions 0, 2, 4, ...
    return odd_bits | even_bits


def imorton(index: int) -> tuple[int, int]:
    """Invert ``morton``: split a 2D Morton index back into ``(x, y)``.

    ``x`` is rebuilt from the odd bit positions of ``index`` and ``y`` from
    the even positions.
    """
    odd_bits = index >> 1
    return _compact1by1(odd_bits), _compact1by1(index)


def _part1by1(n):
    """Spread the low 16 bits of ``n`` so that bit ``i`` lands on bit ``2*i``.

    The odd bit positions of the result are zero. This leaves room for a
    second spread value, shifted left by one, to be OR-ed in. Bits of ``n``
    above bit 15 are discarded. Works elementwise on NumPy integer arrays
    (wide enough to hold 32-bit results) as well as on plain ints.

    The input is never modified. ``n = n & mask`` builds a new value, whereas
    the augmented ``n &= mask`` would mask a caller's NumPy array in place.
    """
    n = n & 0x0000FFFF
    n = (n | (n << 8)) & 0x00FF00FF
    n = (n | (n << 4)) & 0x0F0F0F0F
    n = (n | (n << 2)) & 0x33333333
    n = (n | (n << 1)) & 0x55555555
    return n


def _compact1by1(n):
    """Inverse of ``_part1by1``: gather the even-position bits of ``n``.

    Bit ``2*i`` of ``n`` moves to bit ``i``. Odd-position bits and bits above
    bit 31 are ignored, so the result fits in 16 bits. Works elementwise on
    NumPy integer arrays as well as on plain ints.

    The input is never modified. ``n = n & mask`` builds a new value, whereas
    ``n &= mask`` would overwrite a caller's NumPy array (e.g. the index array
    handed to ``imorton``) in place.
    """
    n = n & 0x55555555
    n = (n ^ (n >> 1)) & 0x33333333
    n = (n ^ (n >> 2)) & 0x0F0F0F0F
    n = (n ^ (n >> 4)) & 0x00FF00FF
    n = (n ^ (n >> 8)) & 0x0000FFFF
    return n

In [None]:
n = 8  # side length of the square grid
# Ax[r, c] == r (row index) and Ay[r, c] == c (column index).
Ax, Ay = np.mgrid[0:n, 0:n]
# Row-major (raster-order) index of each cell, for comparison with Morton order.
A = Ax * n + Ay
A

In [None]:
# Morton index of every cell: B[r, c] == morton(r, c).
B = morton(Ax, Ay)
# Quotient / remainder of each Morton index by the grid width.
Bx, By = divmod(B, n)
B
print(B)

In [None]:
# Identity 2D affine transform, created for quick interactive inspection.
a = transforms.Affine2D()
# Call the method so the cell displays the inverse transform; a bare
# `a.inverted` would only show the bound-method object.
a.inverted()

In [None]:
# Decode Morton indices 0 .. n*n-1 into an (n*n, 2) array of (x, y) pairs, in curve order.
C = np.array(list(map(imorton, range(n * n))))

In [None]:
def swap(x, y, delta=0.0):
    """Return the pair ``(y, x)`` with ``delta`` added to both components.

    The 2D plot uses it to put the second coordinate on the horizontal axis.
    It works on scalars and on NumPy arrays alike.
    """
    first = y + delta
    second = x + delta
    return first, second

In [None]:
fig, ax = plt.subplots(1, 1)
ax.xaxis.tick_top()
# plt.imshow(B)
ax.invert_yaxis()
# Major ticks sit on cell edges (unlabelled), so the major grid lines outline
# each cell. Minor ticks sit on cell centres and carry the index labels.
# Dedicated names keep the global `a` from an earlier cell intact.
edge_ticks = np.arange(-0.5, n + 0.5, 1.0)
centre_ticks = np.arange(0, n, 1)
plt.xticks(edge_ticks, ["" for _ in edge_ticks])
plt.yticks(edge_ticks, ["" for _ in edge_ticks])
plt.xticks(centre_ticks, [f"{i}" for i in centre_ticks], minor=True)
plt.yticks(centre_ticks, [f"{i}" for i in centre_ticks], minor=True)
# C holds (x, y) pairs where x plays the role of the row index; swapping puts
# y on the horizontal axis and x on the inverted vertical axis, matrix-style.
plt.plot(*swap(*C.T), color="black")
plt.axis("scaled")
plt.grid(True)
plt.show()

## 3D

In [None]:
def morton3D(x: int, y: int, z: int) -> int:
    """Interleave the bits of ``x``, ``y`` and ``z`` into a 3D Morton index.

    Bit ``k`` of ``x`` lands on bit ``3k + 2``, of ``y`` on ``3k + 1`` and of
    ``z`` on ``3k``. Only the low 10 bits of each coordinate are used.
    """
    x_bits = _part1by2(x) << 2
    y_bits = _part1by2(y) << 1
    z_bits = _part1by2(z)
    return x_bits | y_bits | z_bits


def imorton3D(index: int) -> tuple[int, int, int]:
    """Invert ``morton3D``: split a 3D Morton index back into ``(x, y, z)``.

    ``z``, ``y`` and ``x`` are gathered from bit positions ``3k``, ``3k + 1``
    and ``3k + 2`` respectively.
    """
    # Shifts 0, 1, 2 select the z, y and x bit lanes, in that order.
    z, y, x = (_compact1by2(index >> shift) for shift in range(3))
    return x, y, z


def _part1by2(n):
    """Spread the low 10 bits of ``n`` so that bit ``i`` lands on bit ``3*i``.

    Two zero bits separate consecutive input bits, so three spread values
    shifted by 0, 1 and 2 can be OR-ed together without overlap. Bits of
    ``n`` above bit 9 are discarded, and the result fits in 28 bits. Works
    elementwise on NumPy integer arrays as well as on plain ints.

    The input is never modified: ``n = n & mask`` builds a new value instead
    of masking a caller's NumPy array in place as ``n &= mask`` would.
    """
    n = n & 0x000003FF  # only 10 bits
    n = (n | (n << 16)) & 0x030000FF
    n = (n | (n << 8)) & 0x0300F00F
    n = (n | (n << 4)) & 0x030C30C3
    n = (n | (n << 2)) & 0x09249249
    return n


def _compact1by2(n):
    """Inverse of ``_part1by2``: gather every third bit of ``n``.

    Bits at positions 0, 3, 6, ..., 27 move to positions 0 .. 9. All other
    bits are ignored, so the result fits in 10 bits. Works elementwise on
    NumPy integer arrays as well as on plain ints.

    The input is never modified: ``n = n & mask`` builds a new value instead
    of masking a caller's NumPy array in place as ``n &= mask`` would.
    """
    n = n & 0x09249249
    n = (n ^ (n >> 2)) & 0x030C30C3
    n = (n ^ (n >> 4)) & 0x0300F00F
    n = (n ^ (n >> 8)) & 0x030000FF
    n = (n ^ (n >> 16)) & 0x000003FF
    return n

In [None]:
m = 8  # side length of the cube; the curve visits m**3 cells
# Decode every Morton index, in order, into (x, y, z) cell coordinates.
C = np.array([imorton3D(i) for i in range(m**3)])

# Create 3D figure
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# Draw the curve as one polyline through the cells in Morton order.
ax.plot(*C.T)

# Labels
ax.set_xlabel("X Axis")
ax.set_ylabel("Y Axis")
ax.set_zlabel("Z Axis")

plt.show()