<a href="https://colab.research.google.com/github/drscook/Billiards2020_with_Jax/blob/master/billiards2020.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup: mount Google Drive, cd into the project folder and run the helper
# notebooks there. GIT_PATH is not defined in this cell -- presumably set by
# git.ipynb; confirm.
PROJ_SUBPATH = 'My Drive/active/billiards2020/'
ROOT_PATH = '/content/drive/'
PROJ_PATH = ROOT_PATH + PROJ_SUBPATH
from google.colab import drive
drive.mount(ROOT_PATH)
%cd '{PROJ_PATH}'
%run git.ipynb
%cd "{GIT_PATH}"
%run helpers.ipynb

In [None]:
# NOTE(review): numpy_or_jax is not defined in this notebook; presumably it comes
# from helpers.ipynb (pick_jnp, defined below, plays the same role) -- confirm.
num_part=1000
rng = numpy_or_jax(num_part)
# rng = pick_jnp(force_jax=True)



In [None]:
# git_push is not defined here; presumably provided by git.ipynb -- confirm.
git_push()

In [None]:
# ! pip install ipython-autotime
# %load_ext autotime
# ! pip install codetiming
# from codetiming import Timer

from math import log10, ceil
import numpy as np
import scipy.stats as stats
# Prefer jax when it imports cleanly; otherwise fall back to numpy under the same
# `jnp` alias so the helpers below work with either backend.
try:
    import jax
    import jaxlib
    import jax.numpy as jnp
    jax.config.update('jax_enable_x64', True)  # 64-bit floats/ints, matching numpy's defaults
    print(f"Using jax {jax.__version__} and jaxlib {jaxlib.__version__}")
    JAX_AVAILABLE = True
except Exception as err:
    # `except Exception` rather than a bare except so Ctrl-C / SystemExit still
    # propagate; jax initialisation may fail with errors other than ImportError.
    print('Error importing jax; importing numpy as jnp instead')
    print(err)
    import numpy as jnp
    JAX_AVAILABLE = False


##################### Utility Functions #####################
def arrify(x):
    """Return `x` as an array with at least one dimension.

    Objects that already have a `.shape` (numpy or jax arrays, numpy scalars) are
    kept on their own backend; anything else (python scalars, lists, tuples) is
    converted with `jnp.asarray`. 0-d results are reshaped to shape (1,).
    """
    # hasattr only swallows AttributeError, unlike the bare `except:` it replaces.
    if not hasattr(x, 'shape'):
        x = jnp.asarray(x)
    if x.ndim == 0:
        x = x.reshape(-1)
    return x

def metric_prefix(x, sigfigs=3):
    """Format `x` with an SI prefix, rounded to `sigfigs` significant figures.

    Returns "<number> <prefix>" so callers can append a unit, e.g.
    f'{metric_prefix(t)}s'. Examples: 0.00123 -> '1.23 m', 4567 -> '4.57 k',
    1000 -> '1.00 k'. Zero returns '0 ', negative values keep their sign, and
    magnitudes outside the table use its largest ('T') or smallest ('p') prefix.
    Raises AssertionError if `sigfigs` does not round to a positive integer.
    """
    from math import floor

    s = round(sigfigs)
    assert s > 0, f'sigfigs must round to > 0, got {sigfigs}'
    if x == 0:
        return '0 '  # log10(0) is undefined; Timer reports a mean of 0.0 before any loop
    sign = '-' if x < 0 else ''
    x = abs(x)
    prefixes = [(12,'T'), (9,'G'), (6,'M'), (3,'k'), (0,''), (-3,'m'), (-6,'µ'), (-9,'n'), (-12,'p')]

    def n_digits(v):
        # digits left of the decimal point: 1000 -> 4, 999 -> 3, 0.05 -> -1
        # (ceil(log10(v)) would give 3 for 1000 and choose the wrong prefix)
        return floor(log10(v)) + 1

    # Round to s significant figures first so e.g. 999.9 rolls over to '1.00 k'.
    x = round(x, s - n_digits(x))
    a = n_digits(x)
    # First prefix whose exponent is strictly below the digit count; values too
    # small for every entry fall back to the last (smallest) prefix.
    b, pref = next(((e, p) for e, p in prefixes if a > e), prefixes[-1])
    c = s - a + b  # decimals needed after scaling by 10**b
    y = round(x / 10**b, c)
    num = int(y) if c <= 0 else format(y, f'.{c}f')
    return f'{sign}{num} {pref}'

##################### Numpy & Jax Standardization #####################
def outer(x, y, func=None):
    """Generalized outer operation: out[i..., j...] = func(x[i...], y[j...]).

    x and y go through `arrify`, so scalars behave as shape-(1,) arrays and the
    result has shape arrify(x).shape + arrify(y).shape.

    `func` must be a broadcasting binary function (e.g. jnp.subtract, np.add). It
    defaults to `jnp.multiply`, looked up at call time so the default follows
    whichever backend `jnp` is bound to when called (pick_jnp rebinds it), not
    the one bound when this function was defined.
    """
    if func is None:
        func = jnp.multiply
    a, b = arrify(x), arrify(y)
    a_ndim, b_ndim = a.ndim, b.ndim
    # Trailing singleton axes on `a` and leading ones on `b` make broadcasting
    # pair every element of x with every element of y.
    a = a.reshape(tuple(a.shape) + (1,) * b_ndim)
    b = b.reshape((1,) * a_ndim + tuple(b.shape))
    return func(a, b)

def put(arr, vals=0, idx=None):
    """Set arr[idx] = vals for a numpy or jax array and return the result.

    jax arrays are immutable, so they are updated functionally through
    `arr.at[idx].set(vals)` and a new array is returned. numpy arrays are modified
    in place and also returned. Always use the return value:
    ``arr = put(arr, vals, idx)``.
    """
    try:
        ref = arr.at[idx]  # jax arrays expose functional index updates via `.at`
    except AttributeError:  # numpy arrays have no `.at`: assign in place
        arr[idx] = vals
        return arr
    # Outside the try so genuine jax errors (e.g. shape mismatches) surface as-is
    # instead of being retried as an in-place assignment on an immutable array.
    return ref.set(vals)

def add(arr, vals=0, idx=None):
    """Add `vals` to arr[idx] for a numpy or jax array and return the result.

    jax arrays are updated functionally via `arr.at[idx].add(vals)` (a new array);
    numpy arrays are updated in place with `arr[idx] += vals` and returned.
    """
    try:
        ref = arr.at[idx]  # jax arrays expose functional index updates via `.at`
    except AttributeError:  # numpy arrays have no `.at`: update in place
        arr[idx] += vals
        return arr
    # Outside the try so genuine jax errors are not masked by a numpy-style retry.
    return ref.add(vals)

def mul(arr, vals=1, idx=None):
    """Multiply arr[idx] by `vals` for a numpy or jax array and return the result.

    jax arrays are updated functionally through `arr.at[idx]` (a new array);
    numpy arrays are updated in place with `arr[idx] *= vals` and returned.
    """
    try:
        ref = arr.at[idx]  # jax arrays expose functional index updates via `.at`
    except AttributeError:  # numpy arrays have no `.at`: update in place
        arr[idx] *= vals
        return arr
    # NOTE(review): newer jax names this update `multiply`; older releases only
    # provide `mul`. Use whichever exists.
    update = getattr(ref, 'multiply', None) or ref.mul
    return update(vals)

def minimum(arr, vals=np.inf, idx=None):
    """Set arr[idx] = min(arr[idx], vals) for a numpy or jax array; return the result.

    The default vals=+inf leaves the array unchanged. (np.inf replaces np.PINF,
    which NumPy 2.0 removed.) jax arrays are updated functionally via
    `arr.at[idx].min`; numpy arrays are updated in place and returned.
    """
    try:
        ref = arr.at[idx]  # jax arrays expose functional index updates via `.at`
    except AttributeError:  # numpy arrays have no `.at`: update in place
        arr[idx] = np.minimum(arr[idx], vals)
        return arr
    # Outside the try so genuine jax errors are not masked by a numpy-style retry.
    return ref.min(vals)

def maximum(arr, vals=-np.inf, idx=None):
    """Set arr[idx] = max(arr[idx], vals) for a numpy or jax array; return the result.

    The default vals=-inf leaves the array unchanged. (-np.inf replaces np.NINF,
    which NumPy 2.0 removed.) jax arrays are updated functionally via
    `arr.at[idx].max`; numpy arrays are updated in place and returned.
    """
    try:
        ref = arr.at[idx]  # jax arrays expose functional index updates via `.at`
    except AttributeError:  # numpy arrays have no `.at`: update in place
        arr[idx] = np.maximum(arr[idx], vals)
        return arr
    # Outside the try so genuine jax errors are not masked by a numpy-style retry.
    return ref.max(vals)

class jax_RandomState():
    """
    Jax rng object with same syntax as np.random.RandomState.

    jax random functions are pure: a given key always produces the same numbers.
    Each draw therefore splits `self.key` and consumes a fresh subkey, so
    successive calls return new values (as np.random.RandomState does) while the
    whole sequence stays reproducible from `seed`.
    """
    def __init__(self, seed=42):
        self.seed = seed
        self.key = jax.random.PRNGKey(seed)
        self.tol = 1e-5  # |scale - 1| and |loc| below this are treated as exactly 1 and 0

    def _next_key(self):
        """Advance the internal key and return a fresh subkey for one draw."""
        self.key, subkey = jax.random.split(self.key)
        return subkey

    @staticmethod
    def _shape(size):
        """Convert a numpy-style `size` (None, int, or sequence of ints) to a jax shape tuple."""
        if size is None:
            return ()
        return tuple(int(n) for n in np.atleast_1d(size))

    def randint(self, low, high=None, size=1, dtype=None):
        if high is None:  # numpy convention: randint(n) draws from [0, n)
            high = low
            low = 0
        if low > high:
            print("specified low > high; I'll swap them")
            low, high = high, low
        kwargs = {} if dtype is None else {'dtype': dtype}  # otherwise let jax pick its default int dtype
        return jax.random.randint(self._next_key(), shape=self._shape(size),
                                  minval=low, maxval=high, **kwargs)

    def uniform(self, low=0, high=1, size=1, dtype=None):
        # `dtype` is accepted for signature compatibility but not forwarded to jax.
        return jax.random.uniform(self._next_key(), shape=self._shape(size),
                                  minval=low, maxval=high)

    def normal(self, loc=0, scale=1, size=1):
        x = jax.random.normal(self._next_key(), shape=self._shape(size))
        if abs(scale - 1.0) > self.tol:
            x = x * scale
        if abs(loc) > self.tol:
            x = x + loc
        return x

    def multivariate_normal(self, mean=0, cov=None, size=1, dtype=None):
        """Draw `size` samples of N(mean, cov); a 1-D mean gives shape size + (dim,).

        When dim == 1 the result has shape `size` (drawn via normal()).
        Raises Exception if cov is not a (dim, dim) matrix.
        """
        mean = jax.numpy.atleast_1d(jax.numpy.asarray(mean))
        dim = mean.shape[0]
        if cov is None:
            cov = jax.numpy.eye(dim)
            print(f"cov not specified; using identity matrix")
        cov = jax.numpy.asarray(cov)
        sh = cov.shape
        if (len(sh) != 2) or (sh[0] != dim) or (sh[0] != sh[1]):
            raise Exception(f"cov has shape {sh}; should be ({dim}, {dim})")
        if dim == 1:
            # normal() takes a standard deviation while cov holds a variance.
            return self.normal(loc=mean[0], scale=jax.numpy.sqrt(cov[0, 0]), size=size)
        kwargs = {} if dtype is None else {'dtype': dtype}
        return jax.random.multivariate_normal(self._next_key(), mean=mean, cov=cov,
                                              shape=self._shape(size), **kwargs)

    def choice(self, a, replace=True, p=None, size=1):
        if isinstance(a, (list, tuple)):
            a = jax.numpy.asarray(a)  # jax.random.choice expects an int or an array
        try:
            return jax.random.choice(self._next_key(), a, replace=replace, p=p, shape=self._shape(size))
        except Exception as e:
            print(f"You're running jax {jax.__version__}, but random.choice appears to be implemented in jax 0.1.71.  Sadly, jax > 0.1.69 can not utilize colab GPU (as of 2020-07-11).")
            print(e)
        return

    def permutation(self, x):
        return jax.random.permutation(self._next_key(), x)

def pick_jnp(num_part=1, seed=42, force_numpy=False, force_jax=False, max_time=4):
    """
    Determines whether to use jax or numpy.  Can force it, or allow detection.

    Rebinds the module-level `jnp` to jax.numpy or numpy, sets the global flag
    `USE_JAX`, and returns a matching generator seeded with `seed`
    (`jax_RandomState` or `np.random.RandomState`).

    Without a force flag, and when jax is available, it times
    `outer(x, x, subtract)` on `num_part` uniform samples with each backend for
    about `max_time` seconds apiece and picks the faster one. The benchmark inputs
    are left in the globals `nx` (numpy) and `jx` (jax).
    """
    global jnp, USE_JAX, nx, jx
    if force_jax and JAX_AVAILABLE:
        print("Given force_jax = True")
        USE_JAX = True
    elif force_numpy:
        print("Given force_numpy = True")
        USE_JAX = False
    elif not JAX_AVAILABLE:
        print("Jax not available")
        USE_JAX = False
    else:
        # `np` is the module-level numpy. Do not re-import it in this function:
        # a local `import numpy as np` would make `np` local to the whole function
        # and raise UnboundLocalError in the numpy branch at the end whenever
        # this timing branch is skipped.
        @Timer(name='np', loop_time=max_time, quiet=True)
        def f_np(x):
            return outer(x, x, np.subtract)

        import jax.numpy as jnp  # rebinds the module-level jnp (declared global above)

        @Timer(name='jax', loop_time=max_time, quiet=True)
        def f_jax(x):
            # jax dispatches work asynchronously; block so the timer measures the
            # computation itself, not just the dispatch.
            return outer(x, x, jnp.subtract).block_until_ready()

        nx = np.random.uniform(size=num_part); jx = jnp.array(nx)
        f_np(nx); f_jax(jx)
        nt, jt = Timer.timers['np'], Timer.timers['jax']
        USE_JAX = jt < nt
        print(f"For {num_part} particles, numpy time = {metric_prefix(nt)}s vs jax time = {metric_prefix(jt)}s")

    if USE_JAX:
        print("importing jax.numpy as jnp.")
        import jax.numpy as jnp
        rng = jax_RandomState(seed)
    else:
        print("importing numpy as jnp.")
        import numpy as jnp
        rng = np.random.RandomState(seed)
    return rng

In [None]:
! pip install datadict
import functools
import time
from dataclasses import field#, dataclass
from datadict import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional

class TimerError(Exception):
    """Raised on Timer misuse: start() while already running, or stop() when not running."""

    pass

@dataclass
class Timer:
    """Time your code using a class, context manager, or decorator.

    Each start()/stop() pair counts as one loop: `elapsed_time` accumulates the
    seconds and `loops` counts the pairs. When a named timer stops, its mean time
    per loop is stored in the class-level dict ``Timer.timers[name]``.

    Decorator use with ``loop_time > 0``: every call first runs the function once
    untimed (warm-up, e.g. for jit compilation), then times repeated calls until
    about `loop_time` seconds have accumulated *for that call*, and returns the
    warm-up result. With ``loop_time <= 0`` a single call is timed.
    """

    timers: ClassVar[Dict[str, float]] = dict()
    name: Optional[str] = None
    # text: str = "Mean elapsed time: {:0.4f} seconds over {} loops"
    text: str = "Mean elapsed time: {}s over {} loops"  # formatted with (metric_prefix(mean), loops)
    logger: Optional[Callable[[str], None]] = print    # None disables reporting
    _start_time: Optional[float] = field(default=None, init=False, repr=False)
    loop_time: float = -1.0
    quiet: bool = False  # True: end() does not report

    def __post_init__(self) -> None:
        """Initialization: add timer to dict of timers and zero the statistics."""
        if self.name:
            self.timers.setdefault(self.name, 0)
        self.reset()

    def reset(self) -> None:
        """Zero the accumulated loop count and elapsed time."""
        self.loops = 0
        self.elapsed_time = 0.0

    def __mean_time__(self) -> float:
        """Mean seconds per completed loop (0.0 before any loop); cached in self.mean_time."""
        if self.loops <= 0:
            self.mean_time = 0.0
        else:
            self.mean_time = self.elapsed_time / self.loops
        return self.mean_time

    def start(self) -> None:
        """Start a new timer"""
        if self._start_time is not None:
            raise TimerError(f"Timer is running. Use .stop() to stop it")
        self._start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop the timer, record one loop, and return the mean time per loop."""
        if self._start_time is None:
            raise TimerError(f"Timer is not running. Use .start() to start it")

        # Calculate elapsed time
        self.elapsed_time += (time.perf_counter() - self._start_time)
        self.loops += 1
        self._start_time = None
        if self.name:
            self.timers[self.name] = self.__mean_time__()
        return self.__mean_time__()

    def __report__(self) -> None:
        """Log the mean time per loop through self.logger, if one is set."""
        if self.logger:
            self.logger(self.text.format(metric_prefix(self.__mean_time__()), self.loops))

    def end(self) -> float:
        """Stop the timer if running, report unless quiet, and return the mean time."""
        if self._start_time is not None:
            self.stop()
        if not self.quiet:
            self.__report__()
        return self.__mean_time__()

    def __enter__(self) -> "Timer":
        """Start a new timer as a context manager"""
        self.start()
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Stop the context manager timer"""
        self.end()

    def __call__(self, func: Callable) -> Callable:
        """Support using Timer as a decorator w/added looping functionality"""
        @functools.wraps(func)
        def wrapper_timer(*args, **kwargs):
            if self.loop_time > 0:
                # Execute function once to clear initial overhead before timing
                out = func(*args, **kwargs)
                # Start each decorated call from zero: otherwise time accumulated
                # by earlier calls already exceeds loop_time, nothing is re-timed,
                # and a stale mean is reported.
                self.reset()
                while self.elapsed_time < self.loop_time:
                    self.start()
                    func(*args, **kwargs)
                    self.stop()
                self.end()
                return out
            with self:
                return func(*args, **kwargs)
        return wrapper_timer

In [None]:
# Benchmark `put` on every other element of a jax array: `p` is run once as a
# warm-up and then repeatedly for ~4 s of timed calls (Timer loop_time=4).
# NOTE(review): put's result is discarded and never blocked on
# (no block_until_ready), so with jax this may time dispatch rather than the
# update itself -- confirm intent.
num_part=1000
# rng = pick_jnp(num_part)
rng = pick_jnp(force_jax=True)
jx = rng.uniform(size=num_part)
# jx.shape
@Timer(name='put', loop_time=4)
def p(jx):
    put(jx, 0, slice(0, num_part, 2))
p(jx)

In [None]:
# class BaseClass():
#     def __getitem__(self, key):
#         return getattr(self,key)

#     def __setitem__(self, key, val):
#         setattr(self, key, val)

#     def __delitem__(self, key):
#         delattr(self,key)

from datadict import dataclass
@dataclass
class Particles():#BaseClass):
    """Particle-system parameters, grouped by each field's metadata 'grp' tag.

    - invariants (dim, num): plain values.
    - scalars (collisions, mass, radius, gamma): stored as length-`num` arrays of
      the annotated dtype; the given values fill the start and the last given
      value is repeated for the remaining particles.
    - vectors (pos, vel, rot, spin): (num, dim) arrays initialised to +inf.

    Raises AssertionError if any invariant or scalar is negative, or gamma > 1.
    NOTE(review): item access (``self[key]``) is assumed to come from datadict's
    `dataclass` decorator -- confirm.
    """
    dim: int        = field(default=3,   metadata={'grp':'invariant'})
    num: int        = field(default=1,   metadata={'grp':'invariant'})
    collisions: int = field(default=0,   metadata={'grp':'scalar'})
    mass: float     = field(default=1.0, metadata={'grp':'scalar'})
    radius: float   = field(default=1.0, metadata={'grp':'scalar'})
    gamma: float    = field(default=0.5, metadata={'grp':'scalar'})
    pos: float      = field(default=0.0, metadata={'grp':'vector'})
    vel: float      = field(default=0.0, metadata={'grp':'vector'})
    rot: float      = field(default=0.0, metadata={'grp':'vector'})
    spin: float     = field(default=0.0, metadata={'grp':'vector'})

    def __post_init__(self):
        # {field name: annotated type} per group, e.g. self.scalars = {'collisions': int, 'mass': float, ...}
        for grp in ['invariants', 'scalars', 'vectors']:
            self[grp] = {key:val.type for key, val in Particles.__dataclass_fields__.items() if val.metadata['grp'] == grp[:-1]}

        for key in [*self.invariants, *self.scalars]:
            assert (jnp.asarray(self[key]) >= 0).all(), f"{key} must be >= 0; got {self[key]}"

        for key in ['gamma']:
            assert (jnp.asarray(self[key]) <= 1).all(), f"{key} must be <= 1; got {self[key]}"

        for key, dtype in self.scalars.items():
            # if self[key] has fewer than self.num values, repeat the last value
            vals = arrify(self[key]).ravel()[:self.num]   # convert to array, flatten, remove values beyond number of particles
            temp = jnp.full(self.num, vals[-1], dtype)   # temp array filled with last value of vals
            self[key] = put(temp, vals, idx=slice(len(vals)))   # stick vals into start of temp; leave trailing repeats

        for key, dtype in self.vectors.items():
            # np.inf rather than np.PINF, which NumPy 2.0 removed
            self[key] = jnp.full((self.num, self.dim), np.inf, dtype)

        # self.beta = 2 * jnp.arctan(self.gamma)
        # self.moment_of_interia = self.mass * (self.gamma * self.radius) ** 2
        # g = 1 + self.gamma ** 2
        # a = (1 - self.gamma ** 2) / g
        # b = 2 * self.gamma / g

    def __str__(self):
        """One-line summary; runs of repeated scalar values are shortened with '...'."""
        parts = [f'{key}={self[key]}' for key in self.invariants]
        for key in self.scalars:
            val = np.asarray(self[key])
            # Keep each value that differs from its predecessor. A numpy boolean mask
            # (not a python list) so this also works when the data are jax arrays.
            keep = np.concatenate(([True], np.diff(val) != 0))[:val.size]
            shown = f'{val[keep]}'
            if not keep.all():
                shown = shown[:-1] + ' ...]'   # if truncated, insert "..."
            parts.append(f'{key}={shown}')
        return 'Particles: ' + ', '.join(parts)

    def __repr__(self):
        s = self.__str__()
        for key in self.vectors:
            s += f'\n{key}\n{self[key]}'
        return s


# Build a 4-particle system; mass gives one value per particle and gamma a single
# value repeated for all of them.
rng = pick_jnp(force_jax=True)#, force_numpy=True)
part = Particles(num=4, collisions=23.3, mass=[3,6,9,12], gamma=0.75)


In [None]:
# NOTE: `particle` is not defined anywhere in this notebook (the instance above is
# `part`), so this cell raises NameError as written.
num_particles = 20
defaults = {'mass':2, 'radius':10}
for key, val in defaults.items():
    particle[key] = np.full(num_particles, val)

In [None]:
# Same experiment using setattr instead of item assignment (same undefined `particle`).
num_particles = 20
defaults = {'mass':2, 'radius':10}
for key, val in defaults.items():
    setattr(particle, key, np.full(num_particles, val))

In [None]:
print(part)

In [None]:
# asdict() is presumably provided by datadict's dataclass -- confirm.
part.asdict()

In [None]:
# 'nope' is not a Particles field; explores what item access does for a missing key.
part['nope']

In [None]:
d = {'hi':'there'}
[*d.keys(), *d.keys()]

In [None]:
l = np.arange(10)
l.__str__()

In [None]:
# Explore the dataclass field metadata that Particles.__post_init__ groups by.
# [x for x in Particles]
vars(Particles)

l = []
for key, val in Particles.__dataclass_fields__.items():
    # print(type(val.type))
    l.append(val.type)
    # print(val.metadata['grp'])

# nl
# np.arange(10, dtype=l[0])

{key:val.type for key, val in Particles.__dataclass_fields__.items() if val.metadata['grp'] == 'invariant'}

In [None]:
# Casting a negative value to an unsigned dtype wraps around rather than raising.
x = np.array(-2)
x
x.astype('uint')
# NOTE: `x.` below is an incomplete expression (SyntaxError) -- tab-completion leftover.
x.

In [None]:
from dataclasses import make_dataclass

# Dynamically created dataclass with three untyped fields.
Position = make_dataclass('Position', ['name', 'lat', 'lon'])
here = Position('here', 1.2, -5.6)
here

In [None]:
# 'new' is passed positionally as the timer name; loop_time keeps its default
# (-1.0), so each call is timed once and reported.
@Timer('new')
def f(x):
    return 2*x

f(3)

In [None]:
# ! pip install codetiming
# from codetiming import Timer

# t = Timer()

@Timer('hi')
def f(x):
    return 2*x

f(4)

# with Timer('hi'):


In [None]:
num_part = 1000
rng = pick_jnp(num_part)#, force_jax=True)#, force_numpy=True)

sh = (200, 300)
# Both np.random.RandomState.uniform and jax_RandomState.uniform take `size=`;
# `shape=` is not a parameter of either and raises TypeError.
x = rng.uniform(size=sh)

In [None]:
from abc import ABC, abstractmethod

class BaseClass(ABC):
    """Mixin that makes item syntax an alias for attributes: obj['x'] is obj.x."""

    def __getitem__(self, key):
        # Unknown keys surface as AttributeError, exactly as getattr raises it.
        value = getattr(self, key)
        return value

    def __setitem__(self, key, val):
        # obj['x'] = v behaves like obj.x = v.
        setattr(self, key, val)

def check_pos(val, nm=''):
    """Assert that every element of `val` is >= 0 (non-negative, despite the name).

    `val` may be a scalar, a list, or an array. On failure, raises AssertionError
    with a message naming `nm` and showing `val`.
    """
    # The message must follow a comma: `assert cond; msg` makes msg a separate,
    # discarded statement and the assertion has no message.
    assert jnp.all(jnp.asarray(val) >= 0.0), f"all {nm} should be >= 0; got {val}"

@dataclass
class Particles(BaseClass):
    """Per-particle parameters stored as arrays (item access via BaseClass).

    Scalars mass, radius, gamma and collisions become (num, 1) arrays. The given
    values fill the leading rows (values beyond `num` are dropped), and the last
    given value is repeated for the remaining particles. Vectors pos/vel have
    shape (num, dim) and rot/spin (num, dim, dim); all start filled with +inf.

    Raises AssertionError if dim, num or any scalar is negative, or if gamma > 1.
    """
    dim: int
    num: int

    def __init__(self, dim=3, num=1, mass=1.0, radius=2.0, gamma=0.5):
        self.dim = round(dim)
        check_pos(self.dim, 'dim')
        self.num = round(num)
        check_pos(self.num, 'num')

        # scalars
        for key, val in {'mass':mass, 'radius':radius, 'gamma':gamma, 'collisions':0}.items():
            arr = arrify(val).ravel()
            vals = arr[:self.num]  # drop values beyond the number of particles
            # column filled with the last given value, with the given values written into its top rows
            self[key] = put(arr[-1] * jnp.ones([self.num, 1]), vals=vals, idx=(slice(len(vals)), 0))
            check_pos(self[key], key)
        assert jnp.all((0.0 <= self.gamma) & (self.gamma <= 1.0)), f"all gamma should be btw 0 & 1; got {self.gamma}"

        # vectors, initialised to +inf (np.inf; np.PINF was removed in NumPy 2.0)
        for key, d in {'pos':[self.dim], 'vel':[self.dim], 'rot':[self.dim, self.dim], 'spin':[self.dim, self.dim]}.items():
            self[key] = jnp.full([self.num] + d, np.inf)

        self.beta = 2 * jnp.arctan(self.gamma)
        self.moment_of_interia = self.mass * (self.gamma * self.radius) ** 2
        self.moment_of_inertia = self.moment_of_interia  # correctly spelled alias; old name kept for existing callers
        # Terms for the commented-out no_slip_matrix below:
        # g = 1 + self.gamma ** 2
        # a = (1 - self.gamma ** 2) / g
        # b = 2 * self.gamma / g
        # self.no_slip_matrix = jnp.array([[-1, 0, 0],
        #                                  [0, a, -b],
        #                                  [0, -b, -a],
        #                                  ])

    # def get_no_slip_matrix = 

    # def rand_

# 3 particles but only 2 masses/radii given: the last value is repeated for the
# third particle. Scalars are stored as (num, 1) columns, so radius.shape is (3, 1).
num_part = 3
rng = pick_jnp(num_part, force_jax=True)#, force_numpy=True)
part = Particles(dim=3, num=num_part, mass=[2.0, 4.0], radius=[1.0, 3.0])
part.radius.shape
part.collisions

In [None]:
# Sanity check: uniform draws lie strictly inside (0, 1).
a = rng.uniform(size=4)
a
assert ((0 < a) & (a < 1)).all()

In [None]:
# In Python 3, round() of a float with no ndigits returns an int.
type(round(7.3))

In [None]:
! pip install pydap

In [None]:
# Unrelated experiment: open NOAA ERSST v5 monthly sea-surface temperatures
# remotely (pydap, installed above, is presumably the OPeNDAP backend used here).
import xarray as xr
url = "http://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/noaa.ersst.v5/sst.mnmean.nc"
# drop an unnecessary variable which complicates some operations
ds = xr.open_dataset(url, drop_variables=["time_bnds"])
# will take a minute or two to complete
ds = ds.sel(time=slice("1960", "2018")).load()
ds

In [None]:
# `display` is IPython's rich-display function, available in notebooks.
with xr.set_options(display_style="html"):
    display(ds)

In [None]:
ds.sst

In [None]:
# Let pick_jnp choose a backend by timing at 1000 particles, then draw a 2x3 sample
# to see which array type the chosen generator returns.
num_part = 1000
rng = pick_jnp(num_part)#, force_jax=True)#, force_numpy=True)



r = rng.uniform(size=(2,3))
print(type(r), r)

In [None]:
# Inspect jax's configuration (e.g. jax_enable_x64, set when jax was imported).
jax.config.values

In [None]:
# Earlier manual version of pick_jnp's timing logic.
# NOTE: add_outer is not defined in this notebook (outer(x, y, jnp.add) is the
# current equivalent), so this cell raises NameError as written.
seed = 42
num_part = 1000

nx = np.random.uniform(size=num_part)
np_time = %timeit -o -q add_outer(nx, -nx)
np_time = round(np_time.best * 1000000)  # best run, seconds -> µs (np_time is now an int)

import jax.numpy as jnp
jx = jnp.array(nx)
jax_time = %timeit -o -q add_outer(jx, -jx).block_until_ready()
jax_time = round(jax_time.best * 1000000)

if np_time <= jax_time:
    print(f"For {num_part} particles, numpy seems to be faster than jax ({np_time} vs {jax_time} µs).  So, I'll import numpy as jnp.")
    import numpy as jnp
    rng = np.random.RandomState(seed)
else:
    print(f"For {num_part} particles, numpy seems to be slower than jax ({np_time} vs {jax_time} µs).  So, I'll import jax.numpy as jnp.")
    import jax.numpy as jnp
    rng = jax_RandomState(seed)
type(jnp.arange(10))






In [None]:
# Demonstrates that an `import ... as name` inside a function rebinds a module
# global when that name is declared `global` -- the mechanism pick_jnp relies on.
import numpy as jnp
print(jnp)

def f():
    global jnp
    import jax.numpy as jnp

f()
print(jnp)

In [None]:
# IPython help (`?`) for np.random.uniform's signature.
np.random.uniform?

In [None]:
# NOTE: the manual timing cell above rebinds np_time to an int (µs), so `.best`
# fails unless np_time still holds a %timeit result.
np_time.best

In [None]:
# t = 7

# NOTE: the default `a=t` is evaluated when `def` runs; with `t = 7` commented out
# this raises NameError unless some earlier cell bound `t`.
def f(a=t):
    print(a)

# f(5)

In [None]:
# Compare reshape(-1, 1) with [:, newaxis] for making a column vector.
# NOTE: these cells index `rng` like a dict ({'np': ..., 'jax': ...}). That is an
# older design; rng is now a single generator returned by pick_jnp, so
# rng['jax'] fails as written.
x = rng['jax'].uniform(size=num_part)
%timeit x.reshape(-1,1)

x = rng['jax'].uniform(size=num_part)
%timeit x[:,jnp.newaxis]
x.shape
# x = 

In [None]:
x = rng['np'].uniform(size=num_part)
%timeit x.reshape(-1,1)

x = rng['jax'].uniform(size=num_part)
%timeit x[:,jnp.newaxis]
x.shape
# x = 

In [None]:
# Time one broadcast outer difference with the named Timer 'hi'.
with Timer('hi'):
    x.reshape(-1,1) - x.reshape(1,-1)
print(x)

In [None]:
# Same outer difference via the ufunc .outer method. numpy ufuncs have it;
# NOTE(review): whether jax.numpy.subtract has .outer depends on the jax version.
jnp.subtract.outer(x, x)

In [None]:
# Class-level dict: mean seconds per loop for each named Timer.
Timer.timers

In [None]:
# NOTE: `timeit` is not imported anywhere in this notebook (perhaps via
# helpers.ipynb -- confirm), and the next line starts with a stray space.
 timeit.timeit('"-".join(str(n) for n in range(100))', number=10000)

In [None]:
# %timeit x.reshape(-1,1) - x.reshape(1,-1)
# %timeit np.subtract(x.reshape(-1,1), x.reshape(1,-1))

# print(x.shape)
# with Timer('np'):
#     for _ in range(reps):
#         x@x.T
#         # 1+2

# Time jax's outer difference. block_until_ready() waits for the asynchronously
# computed result, so %timeit measures the computation itself.
# NOTE: rng['jax'] assumes the older dict-of-generators `rng`.
x = rng['jax'].uniform(size=num_part)
# %timeit (x.reshape(-1,1) - x.reshape(1,-1)).block_until_ready()
jnp.subtract(x.reshape(-1,1), x.reshape(1,-1))

%timeit jnp.subtract(x.reshape(-1,1), x.reshape(1,-1)).block_until_ready()
# with Timer('jax'):
#     for _ in range(reps):
#         (x@x.T).block_until_ready()
# print(Timer.timers)

# T = {}
# for mod, r in rng.items():


#     if mod == 'np':
#         f = lambda : x@x.T
#     elif mod == 'jax':
#         f = lambda : (x@x.T).block_until_ready()


    
#     T[mod] = round(timeit.timeit("f()", setup="from __main__ import f", number=reps) / reps * 1000000, 3)
    
#     # timeit.timeit(f(), setup ) / reps * 1000
# if T['np'] < T['jax']:
#     print(f"For {num_part} particles, numpy seems to be faster ({T['np']} µs) than jax ({T['jax']} µs).  So, I'll use numpy.")
# else:
#     print(f"For {num_part} particles, numpy seems to be slower ({T['np']} µs) than jax ({T['jax']} µs).  So, I'll use jax.")

# # xnp = rng['np'].uniform(num_part)



In [None]:
# A lambda looks up `x` when called, not when defined.
x = 7
f = lambda : x
f()

In [None]:


# NOTE(review): `arrayify` is not defined in this notebook (presumably an older
# name for arrify), so these two prints raise NameError as written.
arr = np.random.rand(2,3)
print(arrayify(arr).shape)
print(arrayify((3,4,5)).shape)

# %timeit jarrayify(3)
arrify(3)
# %timeit arrayify2(3)

In [None]:
# Matrix product of whichever `x` an earlier cell left bound with its transpose.
x@x.T

In [None]:
import numpy as np

    # return arr

# x = jnp.ones((5, 6))
x = np.ones((5, 6))
y = np.ones((5, 6)) * 5
# x = mulat(x, (1,2), 100)
# minimum() with its defaults (vals=+inf, idx=None) leaves x unchanged.
x = minimum(x)
x
# %timeit mulat(x, (1,2), 100)

# x2 = np.ones((5, 6))
# x2 = mulat2(x2, (1,2), 100)
# %timeit mulat2(x, (1,2), 100)
# # addat(x)
# (x == x2).all()
# print(x.shape, type(x))


# x.at([2,4])

In [None]:
# Explore jax's `.at` index-update helper.
x = jax.numpy.ones((5, 6))
y = x.at[(1,2)]
# x
# NOTE: `y.` below is an incomplete expression (SyntaxError) -- tab-completion leftover.
y.

In [None]:
# NOTE(review): jax.ops.index_mul / jax.ops.index come from old jax releases; newer
# jax uses x.at[idx].multiply(...) -- check against the installed version.
x = jax.numpy.ones((5, 6))
x = np.ones((5, 6))
x = jax.ops.index_mul(x, jax.ops.index[2:4, 3:], 6.)
x

In [None]:
jax.ops.index_mul??

In [None]:

# NOTE: `.at` is indexed, not called (x.at[idx].set(v)), so x.at() fails as
# written; rng['jax'] also assumes the older dict-of-generators `rng`.
N = 10
x = rng['jax'].normal(size=(N, N)).astype('float32')
x.at().set(1.0)

In [None]:
# Explore the np.add ufunc API (cf. np.add.at in the comments below).
# NOTE: np.add is binary, so np.add(7) with a single argument raises TypeError.
x = np.arange(10)
# x.add
np.add
np.add(7)
# np.add.at(x, [3,5], 7)
# x
# np.put(x, [2,5], [100,200])
# np.add.at?

# # np.at?

# N = 10
# x = rng['jax'].normal(size=(N, N)).astype('float32')
# # x
# # jnp.put(x, [2,5], [100,200])
# # x
# jnp.add?
# #.at(x, [2,5], [100,200])

In [None]:
# Benchmark np.dot vs jnp.dot on N x N float32 matrices for N = 1, 51, 101, 151, 201.
# The untimed call before each %timeit is a warm-up (for jax, this includes
# compilation). nts / jts collect each backend's mean run time per N.
# NOTE: rng['np'] / rng['jax'] assume the older dict-of-generators `rng`.
Ns = np.arange(1, 202, 50)
nts = []
jts = []
for N in Ns:
    print()
    print(N)
    print('numpy')
    x = rng['np'] .normal(size=(N, N)).astype('float32')
    np.dot(x, x.T)
    nt = %timeit -o np.dot(x, x.T)
    nts.append(np.mean(nt.all_runs))

    print('Jax')
    y = rng['jax'].normal(size=(N, N)).astype('float32')
    jnp.dot(y, y.T).block_until_ready()
    jt = %timeit -o jnp.dot(y, y.T).block_until_ready()
    jts.append(np.mean(jt.all_runs))


In [None]:
np.mean(nt.all_runs)

In [None]:
import numpy as np
# import jax.numpy as np
x = np.arange(10)
# x = x.at[1].set(5)
# x

type(x)

# NOTE: unfinished -- `def as` is a SyntaxError (`as` is a reserved keyword), so
# this cell does not run.
class arr(np.ndarray):
    def as

# # A.hi()

# def hi(self):
#     print('hi')

# # setattr(np.ndarray, "hi", hi)
# # setattr(A, "hi", hi)
# # A.hi()

# def foo(self):
#     print('hello world!')
# setattr(A, 'foo', foo)
# A.foo()

In [None]:
import jax

# Show the source of jax's `.at` helper (IPython `??`). It is a private path and
# may move between jax versions.
w = jax.numpy.ones(5)
jax.numpy.lax_numpy._IndexUpdateHelper??

In [None]:
import jax
import jax.numpy as np
import numpy as np

def extend_ONB(X):
    """Extend the column(s) of X to an orthonormal basis of R^dim.

    X may be a scalar, a single vector, or a 2-D array; a wide 2-D array is
    transposed so the vectors are columns. Returns a (dim, dim) orthogonal
    matrix Q whose first n columns are the Gram-Schmidt orthonormalization of
    X's n columns, with matching orientation: Q[:, 0] is a positive multiple of
    X[:, 0], and orthonormal columns of X are reproduced exactly.

    Raises ValueError if X has more than 2 dimensions.
    """
    X = np.asarray(X, dtype=float).copy()
    if X.ndim <= 1:
        X = X.reshape(-1, 1)  # a single vector becomes one column
    elif X.ndim > 2:
        raise ValueError(f"array must be 2D; given shape {X.shape}")

    # make n_rows >= n_cols
    if X.shape[0] < X.shape[1]:
        X = X.T
    dim, n = X.shape
    # Extend to ONB using QR-factorization: X = Q @ R, so Q[:, :n] spans X's columns.
    Q, R = np.linalg.qr(X, mode='complete')
    # QR may flip the signs of the first n columns. Flip them back by making
    # diag(R) non-negative. (Comparing with X's top row instead zeroes a column
    # whenever X[0, j] == 0.)
    signs = np.where(np.diag(R)[:n] < 0, -1.0, 1.0)
    Q[:, :n] *= signs
    return Q

def perturb(v):
    """Return v rotated by a small random angle.

    Computes w = T @ R2 @ R1 @ T.T @ v, where:
    - T = extend_ONB(v) is an orthonormal basis whose first column points along v;
    - R1 rotates by theta ~ U(tol, 2*pi/100) in the plane of T's first two columns;
    - R2 fixes the first axis and mixes the others: a random sign flip when
      dim == 2, a random rotation from scipy's special_ortho_group otherwise.
    Hence |w| == |v| and the angle between v and w is theta.

    Reads the module-level globals `rng` (np.random.RandomState or
    jax_RandomState) and `tol`. Raises AssertionError if len(v) < 2.
    """
    import scipy.stats as stats
    dim = len(v)
    assert dim >= 2, "len(v) must be >= 2"

    T = extend_ONB(v)

    # np.random.RandomState returns python/numpy scalars, while jax_RandomState
    # returns shape-(1,) arrays. np.ravel(...)[0] yields a scalar either way, and
    # the element assignments below need a scalar.
    R1 = np.eye(dim)
    theta = float(np.ravel(rng.uniform(low=tol, high=2*np.pi / 100))[0])
    c, s = np.cos(theta), np.sin(theta)
    R1[:2, :2] = [[c, -s], [s, c]]

    R2 = np.eye(dim)
    if dim == 2:
        # An array rather than a list: jax.random.choice expects an int or an array.
        R2[1, 1] = np.ravel(rng.choice(np.array([-1, 1])))[0]
    else:
        R2[1:, 1:] = stats.special_ortho_group(dim-1).rvs()

    A = T @ R2 @ R1 @ T.T
    w = A @ v
    return w

# `seed` is the global set in an earlier cell (42). perturb() reads the globals
# `rng` and `tol` defined here.
rng = jax_RandomState(seed)

dim = 2
tol = 1e-4
v = rng.uniform(size=dim)

print(v.round(3))
# %timeit w = perturb(v)

# c is the cosine of the angle between v and its perturbation w; with theta drawn
# from [tol, 2*pi/100], it should lie in [cos(2*pi/100), cos(tol)].
l = []
for i in range(25):
    # print(i)
    w = perturb(v)
    c = v @ w / np.sqrt((v@v) * (w@w))
    l.append(c)
# w = perturb(v)
# w = perturb(v)

print(l)


In [None]:
# Bare expressions: display the function objects as a quick existence check.
import jax
jax.numpy.linalg.qr
jax.numpy.asarray

In [None]:
import numpy as np
import jax
dim = 3
# box_dims == [10., 11., 12.] (uses the literal 3, not `dim`)
box_dims = np.arange(3) + 10.0
box_dims


class Particles():
    """Minimal particle container: per-particle `mass` and `radius` as jax arrays.

    `mass` and `radius` may be scalars or sequences. The first `num` given values
    are used in order (extra values are dropped), and the last given value is
    repeated for any remaining particles.
    """
    def __init__(self, num=1, mass=1.0, radius=1.0):
        self.num = num
        self.mass = self._pad_last(mass, num)
        self.radius = self._pad_last(radius, num)

    @staticmethod
    def _pad_last(vals, num):
        """Length-`num` float jax array: vals[:num], then repeats of vals[-1]."""
        vals = np.asarray(vals, dtype=float).ravel()
        padded = np.full(num, vals[-1])
        k = min(len(vals), num)
        padded[:k] = vals[:k]
        return jax.numpy.asarray(padded)

# 6 particles with 2 given masses and one radius.
part = Particles(num=6, mass=[2.0, 3.0], radius=3.0)
part.mass


In [None]:
# ravel() of a 0-d array gives shape (1,), the same normalization arrify applies.
w = np.array(5).ravel()

w.shape