## Install Packages



In [1]:
# ! git clone https://github.com/oxcsml/geomstats.git
! pip3 install ./geomstats
! pip3 install jax matplotlib diffrax flax einops tqdm wandb seaborn
#! pip3 install matplotlib==3.1.3 

Processing ./geomstats
  Preparing metadata (setup.py) ... [?25ldone
Collecting scikit-learn>=0.22.1
  Using cached scikit_learn-1.1.2-cp310-cp310-macosx_12_0_arm64.whl (7.7 MB)
Collecting scipy<1.9,>=1.4.1
  Downloading scipy-1.8.1-cp310-cp310-macosx_12_0_arm64.whl (28.7 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m28.7/28.7 MB[0m [31m631.2 kB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:02[0m
Collecting threadpoolctl>=2.0.0
  Using cached threadpoolctl-3.1.0-py3-none-any.whl (14 kB)
Building wheels for collected packages: geomstats
  Building wheel for geomstats (setup.py) ... [?25ldone
[?25h  Created wheel for geomstats: filename=geomstats-2.5.0-py3-none-any.whl size=10076101 sha256=ef9922ac35e2222ddafb61c0ced2ebd825ac8ab6b3ebbf699a6e035c964b1031
  Stored in directory: /private/var/folders/6s/zv4ygprx6jvg4d7jsgbjfc1r0000gn/T/pip-ephem-wheel-cache-_lzahqoi/wheels/e2/40/0b/00724ed1f42afc9561b29ee17a11a7912fb957685eb9614a91
Successfu

In [2]:
%matplotlib inline

## Data

Setting up the utilities to generate random real numbers $\nu \in (0, 10)$ as well as a random maximum denominator $q_{max}$ are sampled. Once these are chosen, two integers $p, q$ with $ q \le q_{max}$ are chosen such that $\frac{p}{q} \approx \nu$.

Currently representing the integers $p, q$ in binary. This may be subject to change. 

In [3]:
from dataclasses import dataclass
from fractions import Fraction
from random import choice, choices
import numpy as np


def int_to_binary(x, width=14):
    return np.array(list(np.binary_repr(x, width=width)), dtype=np.uint8)


@dataclass
class RationalApprox:
    target_real: str
    frac: Fraction
    dtype = np.uint32

    @property
    def numerator(self):
        return self.frac.numerator

    @property
    def denominator(self):
        return self.frac.denominator

    def approximation(self):
        num = self.frac.numerator * 1.
        denom = self.frac.denominator
        return num / denom

    def to_numpy(self):
        num = int_to_binary(self.numerator)
        denom = int_to_binary(self.denominator)
        return np.stack((num, denom))

    def for_batch(self):
        return self.target_real, self.to_numpy()


def rand_frac(decimal_places=15, max_denom=1024):
    # TODO: Refactor this to be in terms of numpy or jax rng for proper reproducibility
    digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    leading_digit = choices(digits, weights=[4, 1, 1, 1, 1, 1, 1, 1, 1, 1])[0]
    remaining = ''.join(choices(digits, k=decimal_places))
    number = f'{leading_digit}.{remaining}'
    return RationalApprox(number, Fraction(number).limit_denominator(max_denom))


def make_batch(batch_size=128):
    denominators = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
    numbers = [rand_frac(max_denom=choice(denominators)).for_batch() for _ in range(batch_size)]
    dec_strings, fractions = zip(*numbers)
    return np.array(dec_strings, dtype=np.float32), np.stack(fractions, axis=0)

## Manifold Random Walks


In [4]:
import os

os.environ["GEOMSTATS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp
import diffrax

In [5]:
from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.product_manifold import ProductSameManifold, ProductSameRiemannianMetric
import geomstats.backend as gs

INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices.
INFO:absl:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
INFO:absl:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:root:Using jax backend


In [6]:
def geodesic_random_walk(manif, num_steps, time, rng_key, x0=None):
    step_size = time / num_steps
    gamma = jnp.sqrt(step_size)
    tangent_dim = manif.embedding_space.dim
    if x0 is None:
        rng_key, x0_key = jax.random.split(rng_key)
        x0 = manif.random_uniform(state=x0_key)
    def grw_step(carry, rv):
        tangent_rv = gamma * manif.to_tangent(rv, x)
        x_new = manif.exp(tangent_rv, carry)
        return x_new, x_new
    rvs = jax.random.normal(rng_key, (num_steps, tangent_dim))
    return jax.lax.scan(grw_step, x0, rvs)

In [9]:
x = jnp.array([1., 0., 0.])
key = jax.random.PRNGKey(314)
key, subkey = jax.random.split(key)
manif = Hypersphere(2)
grw = geodesic_random_walk(manif, 1000, 8, subkey, x0=x)

In [10]:
grw

(DeviceArray([ 0.00349199, -0.67380184,  0.7389034 ], dtype=float32),
 DeviceArray([[ 9.9960518e-01,  2.8096525e-02, -1.0146119e-04],
              [ 9.9198759e-01,  3.8272552e-02, -1.2039840e-01],
              [ 9.8407435e-01, -5.4092366e-02, -1.6932705e-01],
              ...,
              [ 3.6974836e-03, -6.0509509e-01,  7.9614425e-01],
              [ 3.5356847e-03, -6.7413783e-01,  7.3859668e-01],
              [ 3.4919900e-03, -6.7380184e-01,  7.3890340e-01]],            dtype=float32))

In [11]:
import geomstats.visualization as visualization
#import matplotlib.animation as animation
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns


def remove_background(ax):
    ax.set_axis_off()
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    return ax


def latlon_from_cartesian(points):
    r = jnp.linalg.norm(points, axis=-1)
    x = points[..., 0]
    y = points[..., 1]
    z = points[..., 2]

    lat = -jnp.arcsin(z / r)
    lon = jnp.arctan2(y, x)
    # lon = jnp.where(lon > 0, lon - math.pi, lon + math.pi)
    return jnp.concatenate([jnp.expand_dims(lat, -1), jnp.expand_dims(lon, -1)], axis=-1)


def cartesian_from_latlong(points):
    lat = points[..., 0]
    lon = points[..., 1]

    x = jnp.cos(lat) * jnp.cos(lon)
    y = jnp.cos(lat) * jnp.sin(lon)
    z = jnp.sin(lat)

    return jnp.stack([x, y, z], axis=-1)


def get_spherical_grid(N, eps=0.0):
    lat = jnp.linspace(-90 + eps, 90 - eps, N // 2)
    lon = jnp.linspace(-180 + eps, 180 - eps, N)
    Lat, Lon = jnp.meshgrid(lat, lon)
    latlon_xs = jnp.concatenate([Lat.reshape(-1, 1), Lon.reshape(-1, 1)], axis=-1)
    spherical_xs = jnp.pi * (latlon_xs / 180.0) + jnp.array([jnp.pi / 2, jnp.pi])[None, :]
    xs = Hypersphere(2).spherical_to_extrinsic(spherical_xs)
    return xs, lat, lon


def plot_3d(x0s, xts, size, prob):
    fig = plt.figure(figsize=(size, size))
    ax = fig.add_subplot(111, projection="3d")
    ax = remove_background(ax)
    fig.subplots_adjust(left=-0.2, bottom=-0.2, right=1.2, top=1.2, wspace=0, hspace=0)
    # ax.view_init(elev=30, azim=45)
    ax.view_init(elev=0, azim=0)
    cmap = sns.cubehelix_palette(as_cmap=True)
    sphere = visualization.Sphere()
    sphere.draw(ax, color="red", marker=".")
    # sphere_plot(ax)
    # sphere.plot_heatmap(ax, pdf, n_points=16000, alpha=0.2, cmap=cmap)
    for k, (x0, xt) in enumerate(zip(x0s, xts)):
        if x0 is not None:
            cax = ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], s=50, color="green")
        if xt is not None:
            x, y, z = xt[:, 0], xt[:, 1], xt[:, 2]
            c = prob if prob is not None else np.ones([*xt.shape[:-1]])
            cax = ax.scatter(x, y, z, s=50, vmin=0.0, vmax=2.0, c=c, cmap=cmap)
        # if grad is not None:
        #     u, v, w = grad[:, 0], grad[:, 1], grad[:, 2]
        #     quiver = ax.quiver(
        #         x, y, z, u, v, w, length=0.2, lw=2, normalize=False, cmap=cmap
        #     )
        #     quiver.set_array(c)

    plt.colorbar(cax)
    # plt.savefig(out, dpi=dpi, bbox_inches="tight", transparent=True)
    plt.close(fig)
    return fig

INFO:numexpr.utils:NumExpr defaulting to 2 threads.


In [16]:
size = 7
plt.switch_backend("agg")
fig = plt.figure(figsize=(size, size))
ax = fig.add_subplot(111, projection="3d")
ax = remove_background(ax)


In [17]:
path = grw[1]

In [21]:
%matplotlib inline
fig.subplots_adjust(left=-0.2, bottom=-0.2, right=1.2, top=1.2, wspace=0, hspace=0)
    # ax.view_init(elev=30, azim=45)
ax.view_init(elev=0, azim=0)
cmap = sns.cubehelix_palette(as_cmap=True)


ax.scatter(path[:, 0], path[:, 1], path[:, 2])
plt.show()

In [22]:
path

DeviceArray([[ 9.9960518e-01,  2.8096525e-02, -1.0146119e-04],
             [ 9.9198759e-01,  3.8272552e-02, -1.2039840e-01],
             [ 9.8407435e-01, -5.4092366e-02, -1.6932705e-01],
             ...,
             [ 3.6974836e-03, -6.0509509e-01,  7.9614425e-01],
             [ 3.5356847e-03, -6.7413783e-01,  7.3859668e-01],
             [ 3.4919900e-03, -6.7380184e-01,  7.3890340e-01]],            dtype=float32)