<a href="https://colab.research.google.com/github/drscook/Billiards2020_with_Jax/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import google.colab
import getpass

preferred_path = "My Drive/Billiards2020"   # your preferred google drive path

# Reuse repo settings from a previous run of this cell if they exist; otherwise
# start from the defaults below.  Only a missing/malformed `git` is expected
# here, so catch exactly those instead of a bare `except:` (which would also
# swallow KeyboardInterrupt and genuine bugs).
try:
    git['repo']
except (NameError, KeyError, TypeError):
    git = {'repo' : 'Billiards2020_with_Jax',
           'user' : 'drscook',
           'email': 'scook@tarleton.edu'}

for var, secure in [['user', False], ['token', True], ['email', False]]:
    prompt = f'Enter your github {var} '
    try:
        git[var]
        prompt += 'or press enter to use existing '
        if secure:
            new_var = getpass.getpass(prompt)
        else:
            prompt += f'"{git[var]}" '
            new_var = input(prompt)
        if new_var != "":
            git[var] = new_var
    except:
        if secure:
            git[var] = getpass.getpass(prompt)
        else:
            git[var] = input(prompt)
git['url'] = f"https://{git['token']}@github.com/{git['user']}/{git['repo']}.git"
! git config --global user.name "{git['user']}"    # set git info
! git config --global user.email "{git['email']}"  # set git info


# Mount your Google Drive to this Colab instance
# NOTE(review): the documented form is `from google.colab import drive`; confirm
# that `import google.colab` alone makes `google.colab.drive` available.
root_path = "/content/drive/"
google.colab.drive.mount(root_path)

# Local repo folder: <root_path>/<preferred_path>/<repo>, on the mounted Drive.
git['path'] = os.path.join(root_path, preferred_path, git['repo'])  # path to folder for your local repo
os.makedirs(git['path'], exist_ok=True)  # create path if necessary
%cd "{git['path']}"

def push(msg=f"enter commit message here"):
    %cd "{git['path']}"
    ! git add .  # Add any new files
    ! git commit -a -m "{msg}"  # Commit changes
    ! git push  # push changes from your local to your orign

def pull():
    """Pull changes from the remote into the local repo at git['path']."""
    %cd "{git['path']}"
    # fetch + merge from the configured origin
    ! git pull  # pull changes from cook origin to cook local

def clone():
    %cd "{git['path']}"
    ! git clone "{git['url']}" .  # clone cook origin to cook local (if necessary)

# Set up the local repo: clone into git['path'], then pull the latest changes.
clone()
pull()

In [None]:
pull()

In [None]:
# NOTE(review): this commits with push()'s placeholder default message; pass a
# real one, e.g. push("describe the change").
push()

In [None]:
# Run helpers.ipynb in this notebook's namespace (-i); presumably it defines the
# helpers used below (dataclass/field, jnp, arrify, put, numpy_or_jax, ...) - confirm.
%run -i helpers.ipynb

In [None]:
@dataclass
class Particles():
    """Particle parameters for the billiard simulation, grouped by role.

    Every field carries ``metadata['grp']``.  ``__post_init__`` collects each
    group's fields into a ``{name: declared type}`` dict stored on
    ``self.invariants``, ``self.scalars`` and ``self.vectors``, and then:

    * checks that every invariant and scalar is >= 0 and that gamma <= 1;
    * expands each scalar to shape ``(num,)`` with the field's declared dtype.
      The given values come first, values beyond ``num`` are dropped, and the
      last value is repeated to fill the rest;
    * replaces each vector with a ``(num, dim)`` array of +inf;
    * derives ``beta = 2*arctan(gamma)`` and
      ``moment_of_interia = mass * (gamma*radius)**2``.

    NOTE(review): ``self[key]`` item access, ``arrify`` and ``put`` are not
    defined here.  They are presumably supplied by helpers.ipynb (including
    the ``dataclass`` decorator in scope) - confirm.
    """
    dim: int        = field(default=3,   metadata={'grp':'invariants'})
    num: int        = field(default=1,   metadata={'grp':'invariants'})
    collisions: int = field(default=0,   metadata={'grp':'scalars'})
    mass: float     = field(default=1.0, metadata={'grp':'scalars'})
    radius: float   = field(default=1.0, metadata={'grp':'scalars'})
    gamma: float    = field(default=0.5, metadata={'grp':'scalars'})
    pos: float      = field(default=0.0, metadata={'grp':'vectors'})
    vel: float      = field(default=0.0, metadata={'grp':'vectors'})
    rot: float      = field(default=0.0, metadata={'grp':'vectors'})
    spin: float     = field(default=0.0, metadata={'grp':'vectors'})

    def __post_init__(self):
        # Group fields by their 'grp' tag.  The dicts are built locally, so this
        # does not depend on which exception self[grp] raises before a group is set.
        groups = {}
        for key, val in Particles.__dataclass_fields__.items():
            groups.setdefault(val.metadata['grp'], {})[key] = val.type
        for grp, members in groups.items():
            self[grp] = members

        for key in [*self.invariants.keys(), *self.scalars.keys()]:
            assert (jnp.asarray(self[key]) >= 0).all(), f"{key} must be >= 0; got {self[key]}"

        for key in ['gamma']:
            assert (jnp.asarray(self[key]) <= 1).all(), f"{key} must be <= 1; got {self[key]}"

        for key, dtype in self.scalars.items():
            # if self[key] has fewer than self.num values, repeat the last value
            vals = arrify(self[key]).ravel()[:self.num]   # convert to array, flatten, remove values beyond number of particles
            temp = jnp.full(self.num, vals[-1], dtype)   # temp array filled with last value of vals
            self[key] = put(temp, vals, idx=slice(len(vals)))   # stick vals into start of temp; leave trailing repeats

        for key, dtype in self.vectors.items():
            # np.inf rather than np.PINF, which was removed in NumPy 2.0
            self[key] = jnp.full((self.num, self.dim), np.inf, dtype)

        self.beta = 2 * jnp.arctan(self.gamma)
        # (attribute name is misspelled; kept because other code may read it)
        self.moment_of_interia = self.mass * (self.gamma * self.radius) ** 2
        # TODO: no-slip coefficients g = 1 + gamma**2,
        #       a = (1 - gamma**2) / g, b = 2 * gamma / g

    def __str__(self):
        """One-line summary: invariants, then each scalar with padded repeats elided."""
        parts = [f'{key}={self[key]}' for key in self.invariants]

        for key in self.scalars:
            val = self[key]
            # Scalars are padded by repeating their last value, so show the prefix
            # up to the last entry that differs from it, and mark the elided tail
            # with "...".  A boolean mask array (not a Python list, which JAX
            # rejects as an index) finds the entries that differ.
            if val.size == 0:
                keep = 0
            else:
                changed = jnp.nonzero(val != val[-1])[0]
                keep = int(changed[-1]) + 2 if changed.size else 1
            text = f'{val[:keep]}'
            if keep < val.size:
                text = text[:-1] + " ...]"   # insert "..." before the closing bracket
            parts.append(f'{key}={text}')

        return "Particles: " + ", ".join(parts)

    def __repr__(self):
        s = self.__str__()
        for key in self.vectors:
            s += f'\n{key}\n{self[key]}'
        return s

# Demo: 4 particles in 2-D.  collisions=23.3 is stored with the field's declared
# dtype (int); mass is given per particle; gamma is repeated for all particles.
# display() shows __repr__ (including vectors); print() shows the one-line __str__.
rng = numpy_or_jax(force_jax=True)#, force_numpy=True)
part = Particles(dim=2, num=4, collisions=23.3, mass=[3,6,9,12], gamma=0.75)
display(part)
print(part)

In [None]:
part.vectors

In [None]:
# NOTE(review): `git_push` is not defined in this notebook (the helper above is
# `push`); presumably it comes from helpers.ipynb - otherwise this is a NameError.
git_push()

In [None]:
# 2018 two-particle no-slip collision update, kept as text so it can be
# rewritten mechanically into the newer q[i].<attr> notation below.
code2018 = """
        U1_out = ((U1_in - d / (m1 * g1**2) * Lambda_nu(U1_in, nu))
               + (-d / (m1 * r1 * g1**2)) * E_nu(v1_in, nu)
               + (-r2 / r1) * (d / (m1 * g1**2)) * Lambda_nu(U2_in, nu)
               + d / (m1 * r1 * g1**2) * E_nu(v2_in, nu))

        v1_out = ((-r1 * d / m1) * Gamma_nu(U1_in, nu)
               + (v1_in - 2 * m2 / M * Pi_nu(v1_in, nu) - (d / m1) * Pi(v1_in, nu))
               + (-r2 * d / m1) * Gamma_nu(U2_in, nu)
               + (2 * m2 / M) * Pi_nu(v2_in, nu) + (d / m1) * Pi(v2_in, nu))

        U2_out = ((-r1 / r2) * (d / (m2 * g2**2)) * Lambda_nu(U1_in, nu)
               + (-d / (m2 * r2 * g2**2)) * E_nu(v1_in, nu)
               + (U2_in - (d / (m2 * g2**2)) * Lambda_nu(U2_in, nu))
               + (d / (m2 * r2 * g2**2)) * E_nu(v2_in, nu))

        v2_out = ((r1 * d / m2) * Gamma_nu(U1_in, nu)
               + (2 * m1 / M) * Pi_nu(v1_in, nu) + (d / m2) * Pi(v1_in, nu)
               + (r2 * d / m2) * Gamma_nu(U2_in, nu)
               + v2_in - (2 * m1 / M) * Pi_nu(v2_in, nu) - (d / m2) * Pi(v2_in, nu))
"""


from collections import OrderedDict

# (old, new) token templates.  {n} is the 1-based particle label used in the
# 2018 code and {i} is the 0-based index into q.  Order matters: a pattern such
# as 'Lambda_nu(U1_in, nu)' must be rewritten before its substring 'U1_in'.
_rename_templates = [
    ('Lambda_nu(U{n}_in, nu)', 'q[{i}].lambda_nu'),
    ('Gamma_nu(U{n}_in, nu)', 'q[{i}].gamma_nu'),
    ('E_nu(v{n}_in, nu)', 'q[{i}].e_nu'),
    ('Pi_nu(v{n}_in, nu)', 'q[{i}].pi_nu'),
    ('Pi(v{n}_in, nu)', 'q[{i}].pi'),
    ('U{n}_out', 'q[{i}].u_out'),
    ('v{n}_out', 'q[{i}].v_out'),
    ('U{n}_in', 'q[{i}].u'),
    ('v{n}_in', 'q[{i}].v'),
    ('m{n}', 'q[{i}].m'),
    ('r{n}', 'q[{i}].r'),
    ('g{n}', 'q[{i}].g'),
]

d = OrderedDict()
for i in range(2):
    for old_tpl, new_tpl in _rename_templates:
        d[old_tpl.format(n=i + 1)] = new_tpl.format(i=i)

d

# Apply every rewrite, in insertion order, to a copy of the 2018 text.
new_code = code2018
for old, new in d.items():
    new_code = new_code.replace(old, new)

print(new_code)
d

In [None]:
# Scratch: try No_Slip_Quantities (not defined in this notebook; presumably from
# helpers.ipynb) on small lists and inspect its lambda_nu attribute.
a = [1,2,3]
b = [-2,3,4]
nu = [3,5,7]
w = No_Slip_Quantities(a, b, nu)
# print(w.get_outer())
# w.get_outer('subtract')
# w.outer
w.lambda_nu

In [None]:
# Scratch: axis permutation and cyclic shifts.  A notebook displays only the
# cell's last expression, so the transpose shape and `q` lines are not shown.
a = np.arange(120).reshape(2,3,4,5)
np.transpose(a, (1,3,0,2)).shape
q = np.roll(np.arange(6), 3)
q

# Same cyclic shift via index arithmetic, (i + d/2) % d.  Note this rebinds `d`
# (the replacement dict from an earlier cell) to an int.
d = 6
(np.arange(d) + int(d/2)) % d


In [None]:
# NOTE(review): incomplete expression (autocomplete probe) - SyntaxError if run.
w.

In [None]:
# Raises KeyError: 'yo' is not a key of d.
d = {'hi':'there'}
d['yo']

In [None]:
jnp.multiply

In [None]:
# For a 1-D array .T does not change the shape, so Y equals X here.
# `outer` is not defined in this notebook; presumably from helpers.ipynb.
X = jnp.array([1,2])
Y = X.T
print(X, '\n\n',Y)

outer(X, Y, jnp.subtract)

In [None]:
num_part=1000
# rng = numpy_or_jax(num_part)
rng = numpy_or_jax(force_jax=True)

In [None]:
# git_push()

In [None]:
# NOTE(review): incomplete expression (autocomplete probe) - SyntaxError if run.
jnp.linalg.

In [None]:
# NOTE(review): `particle` is not defined anywhere visible in this notebook (the
# instance built above is `part`), so this is presumably a NameError as written.
# This variant also needs item-assignment support on the object.
num_particles = 20
defaults = {'mass':2, 'radius':10}
for key, val in defaults.items():
    particle[key] = np.full(num_particles, val)

In [None]:
# Same experiment using setattr, which needs only writable attributes.
num_particles = 20
defaults = {'mass':2, 'radius':10}
for key, val in defaults.items():
    setattr(particle, key, np.full(num_particles, val))

In [None]:
print(part)

In [None]:
# NOTE(review): `asdict` is not a standard dataclass method; presumably added by
# the dataclass helper from helpers.ipynb - confirm.
part.asdict()

In [None]:
# Probe: item access for a field that does not exist; expected to raise (the
# exception type depends on the helper's __getitem__).
part['nope']

In [None]:
# Unpacking the same keys twice gives ['hi', 'hi'].
d = {'hi':'there'}
[*d.keys(), *d.keys()]

In [None]:
l = np.arange(10)
l.__str__()

In [None]:
# Inspect the field registry of the Particles dataclass.
vars(Particles)

# Declared type of every field, in declaration order.
l = [val.type for val in Particles.__dataclass_fields__.values()]

# {name: type} for the fields tagged 'invariants'.  The metadata tag is the
# plural 'invariants'; comparing against 'invariant' never matched anything.
{key:val.type for key, val in Particles.__dataclass_fields__.items() if val.metadata['grp'] == 'invariants'}

In [None]:
# Probe: cast a negative value to an unsigned dtype (astype casts unsafely by
# default, so -2 wraps around rather than raising).
# NOTE(review): the last line is incomplete, so this whole cell fails to parse.
x = np.array(-2)
x
x.astype('uint')
x.

In [None]:
from dataclasses import make_dataclass

# Demo: a dataclass built at runtime with untyped fields name, lat, lon.
Position = make_dataclass('Position', ['name', 'lat', 'lon'])
here = Position('here', 1.2, -5.6)
here

In [None]:
# `Timer` is not defined in the visible notebook; presumably from helpers.ipynb
# or codetiming (see the commented install below).  Used as a named decorator.
@Timer('new')
def f(x):
    return 2*x

f(3)

In [None]:
# ! pip install codetiming
# from codetiming import Timer

# t = Timer()

@Timer('hi')
def f(x):
    return 2*x

f(4)

# with Timer('hi'):


In [None]:
# NOTE(review): `pick_jnp` is not defined in the visible notebook; presumably from
# helpers.ipynb.  Here uniform() is called with `shape=`, while later cells use
# `size=` - confirm which keyword the returned generator accepts.
num_part = 1000
rng = pick_jnp(num_part)#, force_jax=True)#, force_numpy=True)

sh = (200, 300)
x = rng.uniform(shape=sh)

In [None]:
from abc import ABC, abstractmethod

class BaseClass(ABC):
    """Mixin that mirrors attributes as items: ``obj[name]`` is ``obj.name``."""

    def __getitem__(self, name):
        # Plain attribute lookup, so a missing name raises AttributeError.
        return getattr(self, name)

    def __setitem__(self, name, value):
        # ``obj[name] = value`` is the same as ``obj.name = value``.
        setattr(self, name, value)

def check_pos(val, nm=''):
    """Assert that every entry of ``val`` is non-negative.

    val: scalar, sequence or array; converted with ``jnp.asarray``.
    nm:  name reported in the failure message.

    Raises AssertionError, with a message naming ``nm`` and showing ``val``,
    if any entry is < 0.
    """
    arr = jnp.asarray(val)
    # The message must follow a comma: after `;` it was a separate, discarded statement.
    assert jnp.all(arr >= 0.0), f"all {nm} should be >= 0; got {val}"

@dataclass
class Particles(BaseClass):
    """Particle parameters: per-particle scalars plus +inf-initialised vectors.

    Scalars (mass, radius, gamma, collisions) are stored with shape (num, 1).
    The given values come first, values beyond ``num`` are dropped, and the
    last value is repeated for the remaining particles.  Vectors start at +inf:
    pos and vel have shape (num, dim); rot and spin have shape (num, dim, dim).

    Raises AssertionError if dim, num or any scalar is negative, or if gamma > 1.
    NOTE(review): ``arrify`` and ``put`` come from helpers.ipynb - confirm.
    """
    # Proper annotations: a bare `num` in the class body would be evaluated at
    # class-creation time (a NameError unless some global `num` happens to exist).
    dim: int
    num: int

    def __init__(self, dim=3, num=1, mass=1.0, radius=2.0, gamma=0.5):
        self.dim = round(dim)
        check_pos(self.dim, 'dim')
        self.num = round(num)
        check_pos(self.num, 'num')

        # scalars: drop values beyond num (keeping at least one so arr[-1]
        # exists), then pad with the last value.  Use the rounded self.num so
        # float counts such as 3.0 work.
        for key, val in {'mass':mass, 'radius':radius, 'gamma':gamma, 'collisions':0}.items():
            arr = arrify(val).ravel()[:max(self.num, 1)]
            self[key] = put(arr[-1]*jnp.ones([self.num, 1]), vals=arr, idx=(slice(len(arr)),0))
            check_pos(self[key], key)
        # gamma >= 0 is already enforced by check_pos above.
        assert jnp.all(self.gamma <= 1.0), f"all gamma should be <= 1; got {self.gamma}"

        # vectors (np.inf rather than np.PINF, which was removed in NumPy 2.0)
        for key, shape in {'pos':[self.dim], 'vel':[self.dim], 'rot':[self.dim, self.dim], 'spin':[self.dim, self.dim]}.items():
            self[key] = jnp.full([self.num] + shape, np.inf)

        self.beta = 2 * jnp.arctan(self.gamma)
        # (attribute name is misspelled; kept because other code may read it)
        self.moment_of_interia = self.mass * (self.gamma * self.radius) ** 2
        # TODO: no-slip collision matrix, with g = 1 + gamma**2,
        # a = (1 - gamma**2) / g and b = 2 * gamma / g:
        #     [[-1,  0,  0],
        #      [ 0,  a, -b],
        #      [ 0, -b, -a]]

# Demo: 3 particles; the 2-value mass and radius lists are presumably padded by
# repeating their last value (mass -> [2, 4, 4], radius -> [1, 3, 3]), depending
# on the `put` helper.  Only the last expression (collisions) is displayed.
num_part = 3
rng = pick_jnp(num_part, force_jax=True)#, force_numpy=True)
part = Particles(dim=3, num=num_part, mass=[2.0, 4.0], radius=[1.0, 3.0])
part.radius.shape
part.collisions

In [None]:
! pip install pydap

In [None]:
# Open the NOAA ERSST v5 monthly-mean SST dataset from a remote URL
# (requires network access).
import xarray as xr
url = "http://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/noaa.ersst.v5/sst.mnmean.nc"
# drop an unnecessary variable which complicates some operations
ds = xr.open_dataset(url, drop_variables=["time_bnds"])
# Keep 1960-2018 and read it into memory with .load().
# will take a minute or two to complete
ds = ds.sel(time=slice("1960", "2018")).load()
ds

In [None]:
with xr.set_options(display_style="html"):
    display(ds)

In [None]:
ds.sst

In [None]:
# NOTE(review): `pick_jnp` presumably comes from helpers.ipynb; here uniform()
# is called with `size=` (an earlier cell uses `shape=`) - confirm.
num_part = 1000
rng = pick_jnp(num_part)#, force_jax=True)#, force_numpy=True)



r = rng.uniform(size=(2,3))
print(type(r), r)

In [None]:
# Scratch benchmark: all pairwise differences x_i - x_j of num_part values via
# broadcasting, timed with %timeit.  The commented lines are earlier attempts.
# NOTE(review): assumes `rng` can be indexed by backend name ('jax'); the
# commented loop over rng.items() suggests a dict of generators - confirm.
# %timeit x.reshape(-1,1) - x.reshape(1,-1)
# %timeit np.subtract(x.reshape(-1,1), x.reshape(1,-1))

# print(x.shape)
# with Timer('np'):
#     for _ in range(reps):
#         x@x.T
#         # 1+2

x = rng['jax'].uniform(size=num_part)
# %timeit (x.reshape(-1,1) - x.reshape(1,-1)).block_until_ready()
jnp.subtract(x.reshape(-1,1), x.reshape(1,-1))

%timeit jnp.subtract(x.reshape(-1,1), x.reshape(1,-1)).block_until_ready()
# with Timer('jax'):
#     for _ in range(reps):
#         (x@x.T).block_until_ready()
# print(Timer.timers)

# T = {}
# for mod, r in rng.items():


#     if mod == 'np':
#         f = lambda : x@x.T
#     elif mod == 'jax':
#         f = lambda : (x@x.T).block_until_ready()


    
#     T[mod] = round(timeit.timeit("f()", setup="from __main__ import f", number=reps) / reps * 1000000, 3)
    
#     # timeit.timeit(f(), setup ) / reps * 1000
# if T['np'] < T['jax']:
#     print(f"For {num_part} particles, numpy seems to be faster ({T['np']} µs) than jax ({T['jax']} µs).  So, I'll use numpy.")
# else:
#     print(f"For {num_part} particles, numpy seems to be slower ({T['np']} µs) than jax ({T['jax']} µs).  So, I'll use jax.")

# # xnp = rng['np'].uniform(num_part)



In [None]:
import jax
import jax.numpy as np
import numpy as np

def extend_ONB(X):
    """Extend the given vector(s) to a full orthonormal basis.

    X: a 1-D vector, or a 2-D array whose columns are the vectors to extend.
       If it has more columns than rows, its rows are used instead.  The input
       is intended to be orthonormal; otherwise the leading columns of the
       result are the Gram-Schmidt orthonormalisation of X.

    Returns a (dim, dim) orthogonal matrix Q.  Its first n columns equal the
    orthonormalised input vectors, including their sign.

    Raises ValueError if X has more than 2 dimensions.
    """
    X = np.asarray(X, dtype=float).copy()
    # must be 2D
    if X.ndim == 1:
        X = X[:, np.newaxis]          # a single vector becomes one column
    elif X.ndim > 2:
        raise ValueError(f"array must be 2D; given shape {X.shape}")

    # make n_rows >= n_cols so the vectors are the columns
    if X.shape[0] < X.shape[1]:
        X = X.T
    dim, n = X.shape

    # Extend to an ONB with a complete QR factorisation, X = Q @ R.
    Q, R = np.linalg.qr(X, mode='complete')
    # QR fixes each column of Q only up to sign.  Flipping column j of Q together
    # with row j of R leaves Q @ R unchanged, so make diag(R) non-negative; then
    # Q[:, 0] points along X[:, 0] (and Q[:, :n] == X for orthonormal X).  Unlike
    # a test on the top row, this still works when X[0, j] == 0.
    signs = np.sign(np.diag(R))
    signs[signs == 0] = 1.0           # rank-deficient column: leave it, don't zero it
    Q[:, :n] *= signs
    return Q

def perturb(v):
    """Return ``v`` rotated by a small random angle in a random direction.

    T = extend_ONB(v) is an orthonormal basis built from v.  R1 rotates by a
    random angle drawn with ``rng.uniform(low=tol, high=2*pi/100)`` in the plane
    of the first two basis vectors.  R2 keeps the first basis vector fixed and
    reorients the others: a random +/-1 sign on the second one when
    len(v) == 2, a random SO(len(v)-1) rotation otherwise.  The result is
    T @ R2 @ R1 @ T.T @ v.

    Relies on the module-level ``rng``, ``tol`` and ``extend_ONB``.
    """
    import scipy.stats as stats
    n_dim = len(v)
    assert n_dim >= 2, "len(v) must be >= 2"

    basis = extend_ONB(v)

    # R1: planar rotation by a small random angle in the first two coordinates.
    rot_small = np.eye(n_dim)
    angle = rng.uniform(low=tol, high=2*np.pi / 100)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rot_small[:2, :2] = [[cos_a, -sin_a], [sin_a, cos_a]]

    # R2: random orientation of the remaining basis vectors.
    rot_rest = np.eye(n_dim)
    if n_dim == 2:
        rot_rest[1, 1] = rng.choice([-1,1])
    else:
        rot_rest[1:, 1:] = stats.special_ortho_group(n_dim - 1).rvs()

    transform = basis @ rot_rest @ rot_small @ basis.T
    return transform @ v

# Demo: perturb a random 2-D vector 25 times and collect the cosine of the angle
# between v and each perturbed w.
# NOTE(review): `jax_RandomState` and `seed` are not defined in the visible
# notebook; presumably from helpers.ipynb - confirm.
rng = jax_RandomState(seed)

dim = 2
tol = 1e-4   # lower bound passed to rng.uniform for perturb()'s rotation angle
v = rng.uniform(size=dim)

print(v.round(3))
# %timeit w = perturb(v)

l = []
for i in range(25):
    # print(i)
    w = perturb(v)
    c = v @ w / np.sqrt((v@v) * (w@w))   # cos(angle between v and w)
    l.append(c)
# w = perturb(v)
# w = perturb(v)

print(l)


In [None]:
# Scratch: look up jax.numpy functions (only the last expression is displayed).
import jax
jax.numpy.linalg.qr
jax.numpy.asarray

In [None]:
import numpy as np
import jax
dim = 3
# Box side lengths [10., 11., 12.].
box_dims = np.arange(3) + 10.0
box_dims


class Particles():
    """Minimal particle container with per-particle mass and radius arrays.

    num:    number of particles.
    mass:   scalar or sequence.  Stored as a float array of length ``num``:
            the given values first (extras beyond ``num`` dropped), then the
            last given value repeated.
    radius: scalar or sequence, stored the same way.
    """
    def __init__(self, num=1, mass=1.0, radius=1.0):
        self.num = num
        self.mass = self._fill(mass, num)
        self.radius = self._fill(radius, num)

    @staticmethod
    def _fill(vals, num):
        """Return a length-`num` float array: `vals` first, then `vals[-1]` repeated.

        Raises ValueError if `vals` is empty.
        """
        flat = np.asarray(vals, dtype=float).ravel()
        if flat.size == 0:
            raise ValueError("expected at least one value")
        filled = jax.numpy.full([num], fill_value=flat[-1], dtype='float')
        head = flat[:num]                      # drop values beyond num particles
        return filled.at[:len(head)].set(head)

# Demo: 6 particles from 2 mass values and a single radius.
part = Particles(num=6, mass=[2.0, 3.0], radius=3.0)
part.mass


In [None]:
# ravel() turns the 0-d array into shape (1,).
w = np.array(5).ravel()

w.shape