<a href="https://colab.research.google.com/github/dashstander/simplex-score-matching/blob/main/division.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install Packages



In [1]:
! git clone https://github.com/oxcsml/geomstats.git
! pip3 install ./geomstats
! pip3 install diffrax flax einops tqdm wandb seaborn

Cloning into 'geomstats'...
remote: Enumerating objects: 34627, done.
remote: Counting objects: 100% (29/29), done.
remote: Compressing objects: 100% (12/12), done.
remote: Total 34627 (delta 18), reused 26 (delta 17), pack-reused 34598
Receiving objects: 100% (34627/34627), 92.06 MiB | 22.30 MiB/s, done.
Resolving deltas: 100% (26313/26313), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./geomstats
  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.
Collecting matplotlib>=3.3.4
  Downloading matplotlib-3.5.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting flax
  Downloading flax-0.6.0-py3-none-any.whl (180 kB)
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting wandb
  Downloading wandb-0.13.3-py2.py3-none-any.whl (1.8 MB)
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
Collecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.5-py2.py3-none-any.whl (16 kB)
Collecting GitPython>=1.0.0
  Downloading Gi

## Data

This section sets up utilities that sample a random real number $\nu \in (0, 10)$ together with a random maximum denominator $q_{max}$. Given these, two integers $p, q$ with $q \le q_{max}$ are chosen such that $\frac{p}{q} \approx \nu$.

The integers $p, q$ are currently represented in binary. This may change.
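As a small illustration of the approximation step, the sketch below uses the standard-library `Fraction.limit_denominator`, which the cell that follows also relies on. The decimal string and the two denominator bounds are illustrative choices, not values taken from the notebook.

```python
from fractions import Fraction

# Illustrative target value; any decimal string parses the same way.
nu = '3.141592653589793'

# Closest fraction p/q to nu subject to q <= q_max.
coarse = Fraction(nu).limit_denominator(8)
fine = Fraction(nu).limit_denominator(128)

# Absolute approximation errors for the two bounds.
coarse_err = abs(float(coarse) - float(nu))
fine_err = abs(float(fine) - float(nu))

print(coarse, coarse_err)
print(fine, fine_err)
```

A larger $q_{max}$ can only keep or reduce the approximation error, since every fraction allowed under the smaller bound is also allowed under the larger one.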

In [None]:
from dataclasses import dataclass
from fractions import Fraction
from random import choice, choices
import numpy as np


def int_to_binary(x, width=14):
    return np.array(list(np.binary_repr(x, width=width)), dtype=np.uint8)


@dataclass
class RationalApprox:
    target_real: str
    frac: Fraction
    dtype = np.uint32

    @property
    def numerator(self):
        return self.frac.numerator

    @property
    def denominator(self):
        return self.frac.denominator

    def approximation(self):
        num = self.frac.numerator * 1.
        denom = self.frac.denominator
        return num / denom

    def to_numpy(self):
        num = int_to_binary(self.numerator)
        denom = int_to_binary(self.denominator)
        return np.stack((num, denom))

    def for_batch(self):
        return self.target_real, self.to_numpy()


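def rand_frac(decimal_places=15, max_denom=1024):
    """Sample a decimal string in [0, 10) and approximate it with denominator <= max_denom."""
    # TODO: Refactor this to be in terms of numpy or jax rng for proper reproducibility
    digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    # A leading '0' is four times as likely as any other leading digit.
    leading_digit = choices(digits, weights=[4, 1, 1, 1, 1, 1, 1, 1, 1, 1])[0]
    remaining = ''.join(choices(digits, k=decimal_places))
    number = f'{leading_digit}.{remaining}'
    # limit_denominator returns the closest fraction whose denominator is at most max_denom.
    return RationalApprox(number, Fraction(number).limit_denominator(max_denom))


def make_batch(batch_size=128):
    """Return float32 targets of shape (batch_size,) and bit arrays of shape (batch_size, 2, 14)."""
    # Each sample draws its own maximum denominator from this list.
    denominators = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
    numbers = [rand_frac(max_denom=choice(denominators)).for_batch() for _ in range(batch_size)]
    dec_strings, fractions = zip(*numbers)
    return np.array(dec_strings, dtype=np.float32), np.stack(fractions, axis=0)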
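To make the binary representation concrete, here is a minimal sketch that encodes one numerator/denominator pair as 14-bit arrays and decodes it again. `int_to_binary` is restated so the block runs on its own (it parses digits with `int` rather than relying on NumPy's string-to-integer cast), and `binary_to_int` is a hypothetical helper introduced only for this illustration; it is not part of the notebook.

```python
import numpy as np


def int_to_binary(x, width=14):
    # Most-significant-bit-first encoding, zero-padded to `width` bits.
    return np.array([int(b) for b in np.binary_repr(x, width=width)], dtype=np.uint8)


def binary_to_int(bits):
    # Hypothetical inverse: weight each bit by its power of two and sum.
    weights = 2 ** np.arange(len(bits) - 1, -1, -1)
    return int(np.dot(bits.astype(np.int64), weights))


# Encode the pair (p, q) = (355, 113) the same way RationalApprox.to_numpy stacks it.
pair = np.stack((int_to_binary(355), int_to_binary(113)))

print(pair.shape)
print(binary_to_int(pair[0]), binary_to_int(pair[1]))
```

With a leading digit below 10 and denominators of at most 1024, numerators stay below $2^{14}$, so a width of 14 bits is enough for both rows.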

## Manifold Random Walks


In [2]:
import os

os.environ["GEOMSTATS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp
import diffrax

In [4]:
from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.product_manifold import ProductSameManifold, ProductSameRiemannianMetric
import geomstats.backend as gs